@@ -163,30 +163,6 @@ def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='c
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
         return self
-
-    # @classmethod
-    # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-    #     data = state_dict.pop(prefix.rstrip('.'))
-
-    #     # extracting components for QuantState from state_dict
-    #     qs_dict = {}
-    #     for k, v in state_dict.items():
-    #         if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-    #             qs_dict[k] = v
-    #     state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-    #     qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-    #     if data.device.type != "cuda":
-    #         raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-
-    #     cls.requires_grad = requires_grad
-    #     cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-    #     cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
-    #     cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
-    #     cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
-
-    #     self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-    #     return self, state_dict
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -227,7 +203,7 @@ def to(self, *args, **kwargs):
 
 class Linear4bit(nn.Linear):
 
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         # self.persistent_buffers = []  # TODO consider as way to save quant state
@@ -261,18 +237,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
-    # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-    #                           missing_keys, unexpected_keys, error_msgs):
-    #     # Note: super()._load_from_state_dict() is not called here intentionally.
-    #     if self.bias is not None:
-    #         bias_data = state_dict.pop(prefix + "bias", None)
-    #         self.bias.data = bias_data.to(self.bias.data.device)
-
-    #     self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-    #         state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-    #     )
-    #     unexpected_keys.extend(state_dict.keys())
-
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -295,10 +259,12 @@ def forward(self, x: torch.Tensor):
 
         return out
 
+
 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
 
+
 class LinearNF4(Linear4bit):
     ''' Implements the NF4 data type.
 
@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
 
 
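For reference, a minimal usage sketch of the classes in this diff, assuming a CUDA build of bitsandbytes; the shapes and dtypes are illustrative, not taken from the commit:

import torch
import bitsandbytes as bnb

# LinearNF4 is Linear4bit with quant_type='nf4'. The Params4bit weight is
# quantized when the module is moved to the GPU.
layer = bnb.nn.LinearNF4(128, 256, bias=True, compute_dtype=torch.float16).cuda()

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
out = layer(x)  # shape (4, 256); the 4-bit weight is dequantized inside forward

# _save_to_state_dict (kept above) packs the quantization state into the
# state dict under "weight."-prefixed keys next to the quantized weight.
sd = layer.state_dict()
print([k for k in sd if k.startswith("weight")])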
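The LinearNF4 docstring points to create_normal_map in functional.py. As a rough sketch of the idea only (the actual bitsandbytes map is asymmetric, with 8 positive and 7 negative levels plus an explicit zero; see the link above), NF4 code values are standard-normal quantiles rescaled to [-1, 1]:

import torch
from scipy.stats import norm  # functional.py also builds the map from norm.ppf

def normal_map_sketch(offset=0.9677083, num_values=16):
    # Evenly spaced probabilities through the inverse normal CDF, then
    # rescaled so the largest-magnitude code value is exactly 1.
    probs = torch.linspace(1 - offset, offset, num_values)
    values = torch.from_numpy(norm.ppf(probs.numpy())).float()
    return values / values.abs().max()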