Skip to content

Commit de44be1

Browse files
committed
unshared nested_quant_map
1 parent ffd46ce commit de44be1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bitsandbytes/functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
651651
def as_dict(self, packed=False):
652652
"""
653653
returns dict of tensors and strings to use in serialization via _save_to_state_dict()
654-
param: packed -- returns dict[str, torch.Tensor] for state_dict
654+
param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
655655
"""
656656
qs_dict = {
657657
'quant_type': self.quant_type,
@@ -665,13 +665,14 @@ def as_dict(self, packed=False):
665665
qs_dict.update({
666666
'nested_absmax': self.state2.absmax,
667667
'nested_blocksize': self.state2.blocksize,
668-
'nested_quant_map': self.state2.code,
668+
'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
669669
'nested_dtype': str(self.state2.dtype).strip('torch.'),
670670
'nested_offset': self.offset.item(),
671671
})
672672
if not packed:
673673
return qs_dict
674674

675+
# packed format allows serialization of non-tensor components, critical for saving in safetensors format
675676
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
676677
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
677678
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)

0 commit comments

Comments
 (0)