@@ -155,28 +155,38 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
155155 return self
156156
157157 @classmethod
158- def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
159- data = state_dict.pop(prefix.rstrip('.'))
160-
161- # extracting components for QuantState from state_dict
162- qs_dict = {}
163- for k, v in state_dict.items():
164- if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
165- qs_dict[k] = v
166- state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
167- qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
168-
169- if data.device.type != "cuda":
170- raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
171-
172- cls.requires_grad = requires_grad
173- cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
174- cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
175- cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
176- cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
177-
178- self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
179- return self, state_dict
158+ def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
159+ self = torch.Tensor._make_subclass(cls, data.to(device))
160+ self.requires_grad = requires_grad
161+ self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
162+ self.blocksize = self.quant_state.blocksize
163+ self.compress_statistics = self.quant_state.nested
164+ self.quant_type = self.quant_state.quant_type
165+ return self
166+
167+ # @classmethod
168+ # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
169+ # data = state_dict.pop(prefix.rstrip('.'))
170+
171+ # # extracting components for QuantState from state_dict
172+ # qs_dict = {}
173+ # for k, v in state_dict.items():
174+ # if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
175+ # qs_dict[k] = v
176+ # state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
177+ # qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
178+
179+ # if data.device.type != "cuda":
180+ # raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
181+
182+ # cls.requires_grad = requires_grad
183+ # cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
184+ # cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
185+ # cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
186+ # cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
187+
188+ # self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
189+ # return self, state_dict
180190
181191 def cuda(self, device):
182192 w = self.data.contiguous().half().cuda(device)
@@ -251,17 +261,17 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
251261 for k, v in self.weight.quant_state.as_dict(packed=True).items():
252262 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
253263
254- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
255- missing_keys, unexpected_keys, error_msgs):
256- # Note: super()._load_from_state_dict() is not called here intentionally.
257- if self.bias is not None:
258- bias_data = state_dict.pop(prefix + "bias", None)
259- self.bias.data = bias_data.to(self.bias.data.device)
260-
261- self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
262- state_dict, prefix=prefix + "weight" + ".", requires_grad=False
263- )
264- unexpected_keys.extend(state_dict.keys())
264+ # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
265+ # missing_keys, unexpected_keys, error_msgs):
266+ # # Note: super()._load_from_state_dict() is not called here intentionally.
267+ # if self.bias is not None:
268+ # bias_data = state_dict.pop(prefix + "bias", None)
269+ # self.bias.data = bias_data.to(self.bias.data.device)
270+
271+ # self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
272+ # state_dict, prefix=prefix + "weight" + ".", requires_grad=False
273+ # )
274+ # unexpected_keys.extend(state_dict.keys())
265275
266276 def forward(self, x: torch.Tensor):
267277 # weights are cast automatically as Int8Params, but the bias has to be cast manually
0 commit comments