@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
567567 return out
568568
569569class QuantState :
570+ """container for quantizationstate components to work with Params4bit and similar clases"""
570571 def __init__ (self , absmax , shape = None , code = None , blocksize = None , quant_type = None , dtype = None , offset = None , state2 = None ):
571572 self .absmax = absmax
572573 self .shape = shape
@@ -579,32 +580,35 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
579580 self .nested = state2 is not None
580581
581582 @classmethod
582- def from_kwargs (cls , kwargs , device ):
583-
583+ def from_dict (cls , quant_state_dict : dict [str , torch .Tensor ], device : torch .device ) -> 'QuantState' :
584+ """
585+ unpacks dict of tensors into QuantState
586+ where necessary, convert into strings, torch.dtype, ints, etc.
587+ """
584588 tensor2str = lambda xx : '' .join ([chr (x ) for x in xx ]).strip ('.' )
585589
586- kwargs = {k .split ('.' )[- 1 ] :v for k , v in kwargs .items ()}
590+ quant_state_dict = {k .split ('.' )[- 1 ] :v for k , v in quant_state_dict .items ()}
587591
588- if 'nested_absmax' in kwargs :
589- offset = kwargs ['nested_offset' ]
592+ if 'nested_absmax' in quant_state_dict :
593+ offset = quant_state_dict ['nested_offset' ]
590594 state2 = cls (
591- absmax = kwargs ['nested_absmax' ].to (device ),
592- code = kwargs ['nested_code' ].to (device ),
593- blocksize = kwargs ['nested_blocksize' ].item (),
594- dtype = getattr (torch , tensor2str (kwargs ['nested_dtype' ])),
595+ absmax = quant_state_dict ['nested_absmax' ].to (device ),
596+ code = quant_state_dict ['nested_code' ].to (device ),
597+ blocksize = quant_state_dict ['nested_blocksize' ].item (),
598+ dtype = getattr (torch , tensor2str (quant_state_dict ['nested_dtype' ])),
595599 )
596600 else :
597601 offset , state2 = None , None
598602
599603 quant_state = cls (
600- absmax = kwargs ['absmax' ].to (device ),
601- shape = torch .Size (kwargs ['shape' ]),
602- dtype = getattr (torch , tensor2str (kwargs ['dtype' ])),
603- blocksize = kwargs ['blocksize' ].item (),
604+ absmax = quant_state_dict ['absmax' ].to (device ),
605+ shape = torch .Size (quant_state_dict ['shape' ]),
606+ dtype = getattr (torch , tensor2str (quant_state_dict ['dtype' ])),
607+ blocksize = quant_state_dict ['blocksize' ].item (),
604608 offset = offset ,
605609 state2 = state2 ,
606- quant_type = tensor2str (kwargs ['quant_type' ]),
607- code = kwargs ['code' ].to (device ),
610+ quant_type = tensor2str (quant_state_dict ['quant_type' ]),
611+ code = quant_state_dict ['code' ].to (device ),
608612 )
609613 return quant_state
610614
0 commit comments