@@ -567,14 +567,14 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
567567
568568 return out
569569
570+
570571class QuantState :
571572 """container for quantization state components to work with Params4bit and similar clases"""
572573 valid_quant_types = ('fp4' , 'nf4' )
573574 valid_qs_type_keys = [f"quant_state.bitsandbytes__{ x } " for x in valid_quant_types ]
574575 valid_qs_keys = ['absmax' , 'quant_map' , 'nested_absmax' , 'nested_quant_map' , 'quant_state' ,
575576 'quant_type' , 'blocksize' , 'dtype' , 'shape' , 'nested_blocksize' , 'nested_dtype' , 'nested_offset' ]
576577
577-
578578 def __init__ (self , absmax , shape = None , code = None , blocksize = None , quant_type = None , dtype = None , offset = None , state2 = None ):
579579 self .absmax = absmax
580580 self .shape = shape
@@ -585,7 +585,7 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
585585 self .offset = offset
586586 self .state2 = state2
587587 self .nested = state2 is not None
588-
588+
589589 def __get_item__ (self , idx ):
590590 """
591591 ensures compatibility with older quant state scheme with nested lists.
@@ -598,15 +598,15 @@ def __get_item__(self, idx):
598598 else :
599599 list_repr = [self .absmax , self .shape , self .dtype , self .blocksize , None , self .quant_type ]
600600 return list_repr [idx ]
601-
601+
602602 @classmethod
603603 def from_dict (cls , qs_dict : Dict [str , Any ], device : torch .device ) -> 'QuantState' :
604604 """
605605 unpacks components of state_dict into QuantState
606606 where necessary, convert into strings, torch.dtype, ints, etc.
607607
608608 qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
609-
609+
610610 item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
611611 """
612612
@@ -615,8 +615,8 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
615615 if not len (qs_key ) and 'quant_type' not in qs_dict :
616616 raise ValueError ("Expected packed or unpacked quant_state items, found neither" )
617617 elif len (qs_key ) != 1 :
618- raise ValueError (f"There should be exaclly one quant_state item with key from { self .valid_qs_type_keys } . Detected { len (qs_ley )} such items" )
619-
618+ raise ValueError (f"There should be exaclly one quant_state item with key from { cls .valid_qs_type_keys } . Detected { len (qs_key )} such items" )
619+
620620 # unpacking minor and non-tensor quant state items if necessary
621621 if len (qs_key ) == 1 :
622622 qs_key = qs_key [0 ]
@@ -673,7 +673,7 @@ def as_dict(self, packed=False):
673673 non_tensor_dict = {k : v for k , v in qs_dict .items () if not isinstance (v , torch .Tensor )}
674674 qs_packed_dict ["quant_state." + "bitsandbytes__" + self .quant_type ] = pack_dict_to_tensor (non_tensor_dict )
675675 return qs_packed_dict
676-
676+
677677 def to (self , device ):
678678 # make sure the quantization state is on the right device
679679 self .absmax = self .absmax .to (device )
@@ -682,6 +682,7 @@ def to(self, device):
682682 self .state2 .absmax = self .state2 .absmax .to (device )
683683 self .state2 .code = self .state2 .code .to (device )
684684
685+
685686def quantize_blockwise (A : Tensor , code : Tensor = None , absmax : Tensor = None , out : Tensor = None , blocksize = 4096 , nested = False ) -> Tensor :
686687 """
687688 Quantize tensor A in blocks of size 4096 values.
0 commit comments