@@ -154,14 +154,25 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         return self
 
     @classmethod
-    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
-        self = torch.Tensor._make_subclass(cls, data.to(device))
-        self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
-        self.blocksize = self.quant_state.blocksize
-        self.compress_statistics = self.quant_state.nested
-        self.quant_type = self.quant_state.quant_type
-        return self
+    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
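+        # the packed weight tensor itself is stored under the bare prefix
+        # (e.g. "layer.weight"), while the quant-state components live in
+        # "layer.weight.<key>" entries written by _save_to_state_dict below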
+        data = state_dict.pop(prefix.rstrip('.'))
+
+        # extract the components for QuantState from the state_dict
+        qs_dict = {}
+        for k, v in state_dict.items():
+            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+                qs_dict[k] = v
+        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+
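+        # note: deserialization is only supported onto CUDA tensors here; the
+        # quant state appears to be reconstructible only on a CUDA device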
+        if data.device.type != "cuda":
+            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+
+        self = torch.Tensor._make_subclass(cls, data)
+        # set instance attributes on the new tensor subclass, not on `cls`
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+
+        return self, state_dict
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -200,9 +211,11 @@ def to(self, *args, **kwargs):
         return new_param
 
 class Linear4bit(nn.Linear):
+
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        # self.persistent_buffers = []  # TODO: consider as a way to save the quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
 
@@ -233,6 +246,18 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
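+    # counterpart of _save_to_state_dict above: consumes the flat "weight.*"
+    # entries written there to rebuild the Params4bit and its quant state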
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Note: super()._load_from_state_dict() is intentionally not called here.
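+        # (the default implementation copies into the existing parameters in
+        # place, which presumably cannot work for the repacked 4-bit weight)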
+        if self.bias is not None:
+            bias_data = state_dict.pop(prefix + "bias", None)
+            if bias_data is None:
+                missing_keys.append(prefix + "bias")
+            else:
+                self.bias.data = bias_data.to(self.bias.data.device)
+
+        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+        )
+        # only leftover keys under this module's prefix are unexpected; the
+        # remaining state_dict still holds entries belonging to other modules
+        unexpected_keys.extend(k for k in state_dict if k.startswith(prefix))
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype: