@@ -585,33 +585,54 @@ def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.devi
585585 unpacks dict of tensors into QuantState
586586 where necessary, convert into strings, torch.dtype, ints, etc.
587587 """
588- tensor2str = lambda xx : '' .join ([chr (x ) for x in xx ]).strip ('.' )
589588
590- quant_state_dict = {k .split ('.' )[- 1 ] :v for k , v in quant_state_dict .items ()}
591-
589+ quant_state_dict = {k .split ('.' )[- 1 ]:v for k , v in quant_state_dict .items ()}
590+ if 'quant_state_dict' in quant_state_dict :
591+ quant_state_dict |= quant_state_dict .pop ('quant_state_dict' )
592+
592593 if 'nested_absmax' in quant_state_dict :
593- offset = quant_state_dict ['nested_offset' ]
594+ offset = torch . tensor ( float ( quant_state_dict ['nested_offset' ])). to ( device )
594595 state2 = cls (
595596 absmax = quant_state_dict ['nested_absmax' ].to (device ),
596597 code = quant_state_dict ['nested_code' ].to (device ),
597- blocksize = quant_state_dict ['nested_blocksize' ]. item ( ),
598- dtype = getattr (torch , tensor2str ( quant_state_dict ['nested_dtype' ]) ),
598+ blocksize = int ( quant_state_dict ['nested_blocksize' ]),
599+ dtype = getattr (torch , quant_state_dict ['nested_dtype' ]),
599600 )
600601 else :
601602 offset , state2 = None , None
602603
603604 quant_state = cls (
604- absmax = quant_state_dict ['absmax' ].to (device ),
605- shape = torch .Size (quant_state_dict ['shape' ]),
606- dtype = getattr (torch , tensor2str ( quant_state_dict ['dtype' ]) ),
607- blocksize = quant_state_dict ['blocksize' ]. item ( ),
605+ absmax = quant_state_dict ['absmax' ].to (device ),
606+ shape = torch .Size (map ( int , quant_state_dict ['shape' ]. split ( '.' )) ),
607+ dtype = getattr (torch , quant_state_dict ['dtype' ]),
608+ blocksize = int ( quant_state_dict ['blocksize' ]),
608609 offset = offset ,
609610 state2 = state2 ,
610- quant_type = tensor2str ( quant_state_dict ['quant_type' ]) ,
611+ quant_type = quant_state_dict ['quant_type' ],
611612 code = quant_state_dict ['code' ].to (device ),
612613 )
613614 return quant_state
614615
616+ def as_dict (self ):
617+ """dict of tensors and strings to use in serialization via _save_to_state_dict()"""
618+ qs_dict = {
619+ 'absmax' : self .absmax ,
620+ 'code' : self .code ,
621+ 'shape' : ',' .join (map (str , self .shape )),
622+ 'dtype' : str (self .dtype ).strip ('torch' ),
623+ 'blocksize' : str (self .blocksize ),
624+ 'quant_type' : self .quant_type ,
625+ }
626+ if self .nested :
627+ qs_dict .update ({
628+ 'nested_absmax' : self .state2 .absmax ,
629+ 'nested_code' : self .state2 .code ,
630+ 'nested_offset' : f"{ self .offset .item ()} " ,
631+ 'nested_blocksize' : str (self .state2 .blocksize ),
632+ 'nested_dtype' : str (self .state2 .dtype ).strip ('torch' ),
633+ })
634+ return qs_dict
635+
615636 def to (self , device ):
616637 # make sure the quantization state is on the right device
617638 self .absmax = self .absmax .to (device )
0 commit comments