#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5- from typing import Optional , TypeVar , Union , overload
5+ from typing import Any , Dict , Optional , TypeVar , Union , overload
66
77import warnings
88import torch
@@ -139,9 +139,10 @@ def forward(self, input: Tensor) -> Tensor:
139139
140140 return emb
141141
142+
142143class Params4bit (torch .nn .Parameter ):
143144
144- def __new__ (cls , data = None , requires_grad = True , quant_state = None , blocksize = 64 , compress_statistics = True , quant_type = 'fp4' ):
145+ def __new__ (cls , data : Optional [ torch . Tensor ] = None , requires_grad = True , quant_state : QuantState = None , blocksize : int = 64 , compress_statistics : bool = True , quant_type : str = 'fp4' ) -> "Params4bit" :
145146 if data is None :
146147 data = torch .empty (0 )
147148
@@ -152,27 +153,16 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
152153 self .quant_state = quant_state
153154 self .data = data
154155 return self
155-
156+
156157 @classmethod
157- def from_state_dict (cls , state_dict , prefix = "" , requires_grad = False ):
158- data = state_dict .pop (prefix .rstrip ('.' ))
159-
160- # extracting components for QuantState from state_dict
161- qs_dict = {}
162- for k , v in state_dict .items ():
163- if k .replace (prefix , '' ).split ('.' )[0 ] in QuantState .valid_qs_keys :
164- qs_dict [k ] = v
165- state_dict = {k : v for k , v in state_dict .items () if k not in qs_dict }
166- qs_dict = {k .replace (prefix , '' ): v for k , v in qs_dict .items ()}
167-
168- if data .device .type != "cuda" :
169- raise ValueError (f"`data.device.type` must be 'cuda', detected { data .device .type } " )
170-
171- cls .requires_grad = requires_grad ,
172- cls .quant_state = QuantState .from_dict (qs_dict = qs_dict , device = data .device )
173-
174- self = torch .Tensor ._make_subclass (cls , data = data .to (data .device ))
175- return self , state_dict
158+ def from_prequantized (cls , data : torch .Tensor , quantized_stats : Dict [str , Any ], requires_grad : bool = False , device = 'cuda' , ** kwargs ) -> "Params4bit" :
159+ self = torch .Tensor ._make_subclass (cls , data .to (device ))
160+ self .requires_grad = requires_grad
161+ self .quant_state = QuantState .from_dict (qs_dict = quantized_stats , device = device )
162+ self .blocksize = self .quant_state .blocksize
163+ self .compress_statistics = self .quant_state .nested
164+ self .quant_type = self .quant_state .quant_type
165+ return self
176166
177167 def cuda (self , device ):
178168 w = self .data .contiguous ().half ().cuda (device )
@@ -204,15 +194,16 @@ def to(self, *args, **kwargs):
204194 self .quant_state .to (device )
205195
206196 new_param = Params4bit (super ().to (device = device , dtype = dtype , non_blocking = non_blocking ),
207- requires_grad = self .requires_grad , quant_state = self .quant_state ,
197+ requires_grad = self .requires_grad , quant_state = self .quant_state ,
208198 blocksize = self .blocksize , compress_statistics = self .compress_statistics ,
209199 quant_type = self .quant_type )
210200
211201 return new_param
212202
203+
213204class Linear4bit (nn .Linear ):
214-
215- def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True , quant_type = 'fp4' ,device = None ):
205+
206+ def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True , quant_type = 'fp4' , device = None ):
216207 super ().__init__ (input_features , output_features , bias , device )
217208 self .weight = Params4bit (self .weight .data , requires_grad = False , compress_statistics = compress_statistics , quant_type = quant_type )
218209 # self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
246237 for k , v in self .weight .quant_state .as_dict (packed = True ).items ():
247238 destination [prefix + "weight." + k ] = v if keep_vars else v .detach ()
248239
249- def _load_from_state_dict (self , state_dict , prefix , local_metadata , strict ,
250- missing_keys , unexpected_keys , error_msgs ):
251- # Note: super()._load_from_state_dict() is not called here intentionally.
252- if self .bias is not None :
253- bias_data = state_dict .pop (prefix + "bias" , None )
254- self .bias .data = bias_data .to (self .bias .data .device )
255-
256- self .weight , state_dict = bnb .nn .Params4bit .from_state_dict (
257- state_dict , prefix = prefix + "weight" + "." , requires_grad = False
258- )
259- unexpected_keys .extend (state_dict .keys ())
260-
261240 def forward (self , x : torch .Tensor ):
262241 # weights are cast automatically as Int8Params, but the bias has to be cast manually
263242 if self .bias is not None and self .bias .dtype != x .dtype :
@@ -280,10 +259,12 @@ def forward(self, x: torch.Tensor):
280259
281260 return out
282261
262+
283263class LinearFP4 (Linear4bit ):
284- def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True ,device = None ):
264+ def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True , device = None ):
285265 super ().__init__ (input_features , output_features , bias , compute_dtype , compress_statistics , 'fp4' , device )
286266
267+
287268class LinearNF4 (Linear4bit ):
288269 ''' Implements the NF4 data type.
289270
@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
295276 Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
296277 the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
297278 '''
298- def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True ,device = None ):
279+ def __init__ (self , input_features , output_features , bias = True , compute_dtype = None , compress_statistics = True , device = None ):
299280 super ().__init__ (input_features , output_features , bias , compute_dtype , compress_statistics , 'nf4' , device )
300281
301282
0 commit comments