@@ -651,7 +651,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
651651 def as_dict (self , packed = False ):
652652 """
653653 returns dict of tensors and strings to use in serialization via _save_to_state_dict()
654- param: packed -- returns dict[str, torch.Tensor] for state_dict
654+ param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
655655 """
656656 qs_dict = {
657657 'quant_type' : self .quant_type ,
@@ -665,13 +665,14 @@ def as_dict(self, packed=False):
665665 qs_dict .update ({
666666 'nested_absmax' : self .state2 .absmax ,
667667 'nested_blocksize' : self .state2 .blocksize ,
668- 'nested_quant_map' : self .state2 .code ,
668+ 'nested_quant_map' : self .state2 .code . clone (), # un-shared to avoid restoring it after shared tensors are removed by safetensors
669669 'nested_dtype' : str (self .state2 .dtype ).strip ('torch.' ),
670670 'nested_offset' : self .offset .item (),
671671 })
672672 if not packed :
673673 return qs_dict
674674
675+ # packed format allows serialization of non-tensor components, critical for saving in safetensors format
675676 qs_packed_dict = {k : v for k , v in qs_dict .items () if isinstance (v , torch .Tensor )}
676677 non_tensor_dict = {k : v for k , v in qs_dict .items () if not isinstance (v , torch .Tensor )}
677678 qs_packed_dict ["quant_state." + "bitsandbytes__" + self .quant_type ] = pack_dict_to_tensor (non_tensor_dict )
0 commit comments