Skip to content

Commit 851806e

Browse files
committed
renamed code to quant_map in serialized QState
1 parent 76b40a5 commit 851806e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

bitsandbytes/functional.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -571,14 +571,14 @@ class QuantState:
571571
"""container for quantization state components to work with Params4bit and similar clases"""
572572
valid_quant_types = ('fp4', 'nf4')
573573
valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
574-
valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state',
574+
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
575575
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
576576

577577

578578
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
579579
self.absmax = absmax
580580
self.shape = shape
581-
self.code = code # TODO consider renaming to `buckets / centroids / scale`
581+
self.code = code
582582
self.dtype = dtype
583583
self.blocksize = blocksize
584584
self.quant_type = quant_type
@@ -627,7 +627,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
627627
state2 = cls(
628628
absmax=qs_dict['nested_absmax'].to(device),
629629
blocksize=qs_dict['nested_blocksize'],
630-
code=qs_dict['nested_code'].to(device),
630+
code=qs_dict['nested_quant_map'].to(device),
631631
dtype=getattr(torch, qs_dict['nested_dtype']),
632632
)
633633
else:
@@ -637,7 +637,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
637637
quant_type=qs_dict['quant_type'],
638638
absmax=qs_dict['absmax'].to(device),
639639
blocksize=qs_dict['blocksize'],
640-
code=qs_dict['code'].to(device),
640+
code=qs_dict['quant_map'].to(device),
641641
dtype=getattr(torch, qs_dict['dtype']),
642642
shape=torch.Size(qs_dict['shape']),
643643
offset=offset,
@@ -654,15 +654,15 @@ def as_dict(self, packed=False):
654654
'quant_type': self.quant_type,
655655
'absmax': self.absmax,
656656
'blocksize': self.blocksize,
657-
'code': self.code,
657+
'quant_map': self.code,
658658
'dtype': str(self.dtype).strip('torch.'),
659659
'shape': tuple(self.shape) if self.nested else None,
660660
}
661661
if self.nested:
662662
qs_dict.update({
663663
'nested_absmax': self.state2.absmax,
664664
'nested_blocksize': self.state2.blocksize,
665-
'nested_code': self.state2.code,
665+
'nested_quant_map': self.state2.code,
666666
'nested_dtype': str(self.state2.dtype).strip('torch.'),
667667
'nested_offset': self.offset.item(),
668668
})

0 commit comments

Comments
 (0)