@@ -571,14 +571,14 @@ class QuantState:
571571 """container for quantization state components to work with Params4bit and similar clases"""
572572 valid_quant_types = ('fp4' , 'nf4' )
573573 valid_qs_type_keys = [f"quant_state.bitsandbytes__{ x } " for x in valid_quant_types ]
574- valid_qs_keys = ['absmax' , 'code ' , 'nested_absmax' , 'nested_code ' , 'quant_state' ,
574+ valid_qs_keys = ['absmax' , 'quant_map ' , 'nested_absmax' , 'nested_quant_map ' , 'quant_state' ,
575575 'quant_type' , 'blocksize' , 'dtype' , 'shape' , 'nested_blocksize' , 'nested_dtype' , 'nested_offset' ]
576576
577577
578578 def __init__ (self , absmax , shape = None , code = None , blocksize = None , quant_type = None , dtype = None , offset = None , state2 = None ):
579579 self .absmax = absmax
580580 self .shape = shape
581- self .code = code # TODO consider renaming to `buckets / centroids / scale`
581+ self .code = code
582582 self .dtype = dtype
583583 self .blocksize = blocksize
584584 self .quant_type = quant_type
@@ -627,7 +627,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
627627 state2 = cls (
628628 absmax = qs_dict ['nested_absmax' ].to (device ),
629629 blocksize = qs_dict ['nested_blocksize' ],
630- code = qs_dict ['nested_code ' ].to (device ),
630+ code = qs_dict ['nested_quant_map ' ].to (device ),
631631 dtype = getattr (torch , qs_dict ['nested_dtype' ]),
632632 )
633633 else :
@@ -637,7 +637,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
637637 quant_type = qs_dict ['quant_type' ],
638638 absmax = qs_dict ['absmax' ].to (device ),
639639 blocksize = qs_dict ['blocksize' ],
640- code = qs_dict ['code ' ].to (device ),
640+ code = qs_dict ['quant_map ' ].to (device ),
641641 dtype = getattr (torch , qs_dict ['dtype' ]),
642642 shape = torch .Size (qs_dict ['shape' ]),
643643 offset = offset ,
@@ -654,15 +654,15 @@ def as_dict(self, packed=False):
654654 'quant_type' : self .quant_type ,
655655 'absmax' : self .absmax ,
656656 'blocksize' : self .blocksize ,
657- 'code ' : self .code ,
657+ 'quant_map ' : self .code ,
658658 'dtype' : str (self .dtype ).strip ('torch.' ),
659659 'shape' : tuple (self .shape ) if self .nested else None ,
660660 }
661661 if self .nested :
662662 qs_dict .update ({
663663 'nested_absmax' : self .state2 .absmax ,
664664 'nested_blocksize' : self .state2 .blocksize ,
665- 'nested_code ' : self .state2 .code ,
665+ 'nested_quant_map ' : self .state2 .code ,
666666 'nested_dtype' : str (self .state2 .dtype ).strip ('torch.' ),
667667 'nested_offset' : self .offset .item (),
668668 })
0 commit comments