@@ -593,6 +593,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
593593
594594 # unpacking tensor with non-tensor components
595595 qs_key = [k for k , v in qs_dict .items () if "quant_state" in k and isinstance (v , torch .Tensor )]
596+ assert len (qs_key ) == 1 or not qs_key and 'quant_type' in qs_dict , \
597+ f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: { tuple (qs_dict .keys ())} "
596598 if len (qs_key ) == 1 :
597599 qs_key = qs_key [0 ]
598600 assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key , \
@@ -605,22 +607,22 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
605607 offset = torch .tensor (float (qs_dict ['nested_offset' ])).to (device )
606608 state2 = cls (
607609 absmax = qs_dict ['nested_absmax' ].to (device ),
608- code = qs_dict ['nested_code' ].to (device ),
609610 blocksize = qs_dict ['nested_blocksize' ],
611+ code = qs_dict ['nested_code' ].to (device ),
610612 dtype = getattr (torch , qs_dict ['nested_dtype' ]),
611613 )
612614 else :
613615 offset , state2 = None , None
614616
615617 quant_state = cls (
618+ quant_type = qs_dict ['quant_type' ],
616619 absmax = qs_dict ['absmax' ].to (device ),
617- shape = torch .Size (qs_dict ['shape' ]),
618- dtype = getattr (torch , qs_dict ['dtype' ]),
619620 blocksize = qs_dict ['blocksize' ],
621+ code = qs_dict ['code' ].to (device ),
622+ dtype = getattr (torch , qs_dict ['dtype' ]),
623+ shape = torch .Size (qs_dict ['shape' ]),
620624 offset = offset ,
621625 state2 = state2 ,
622- quant_type = qs_dict ['quant_type' ],
623- code = qs_dict ['code' ].to (device ),
624626 )
625627 return quant_state
626628
@@ -630,20 +632,20 @@ def as_dict(self, packed=False):
630632 param: packed -- returns dict[str, torch.Tensor] for state_dict
631633 """
632634 qs_dict = {
635+ 'quant_type' : self .quant_type ,
633636 'absmax' : self .absmax ,
637+ 'blocksize' : self .blocksize ,
634638 'code' : self .code ,
635- 'shape' : tuple (self .shape ),
636639 'dtype' : str (self .dtype ).strip ('torch.' ),
637- 'blocksize' : self .blocksize ,
638- 'quant_type' : self .quant_type ,
640+ 'shape' : tuple (self .shape ) if self .nested else None ,
639641 }
640642 if self .nested :
641643 qs_dict .update ({
642644 'nested_absmax' : self .state2 .absmax ,
643- 'nested_code' : self .state2 .code ,
644- 'nested_offset' : self .offset .item (),
645645 'nested_blocksize' : self .state2 .blocksize ,
646+ 'nested_code' : self .state2 .code ,
646647 'nested_dtype' : str (self .state2 .dtype ).strip ('torch.' ),
648+ 'nested_offset' : self .offset .item (),
647649 })
648650 if not packed :
649651 return qs_dict
0 commit comments