
Commit 7a117e4

cleanup1
1 parent 6cf0f05 commit 7a117e4

2 files changed: +14 -11 lines changed


bitsandbytes/functional.py

Lines changed: 12 additions & 10 deletions
@@ -593,6 +593,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
 
         # unpacking tensor with non-tensor components
         qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
+        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
+            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
         if len(qs_key) == 1:
             qs_key = qs_key[0]
             assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
@@ -605,22 +607,22 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
             state2 = cls(
                 absmax=qs_dict['nested_absmax'].to(device),
-                code=qs_dict['nested_code'].to(device),
                 blocksize=qs_dict['nested_blocksize'],
+                code=qs_dict['nested_code'].to(device),
                 dtype=getattr(torch, qs_dict['nested_dtype']),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
+            quant_type=qs_dict['quant_type'],
             absmax=qs_dict['absmax'].to(device),
-            shape=torch.Size(qs_dict['shape']),
-            dtype=getattr(torch, qs_dict['dtype']),
             blocksize=qs_dict['blocksize'],
+            code=qs_dict['code'].to(device),
+            dtype=getattr(torch, qs_dict['dtype']),
+            shape=torch.Size(qs_dict['shape']),
             offset=offset,
             state2=state2,
-            quant_type=qs_dict['quant_type'],
-            code=qs_dict['code'].to(device),
         )
         return quant_state
 
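As context for the new assert: from_dict accepts either a packed dict, where a single tensor sits under a key containing "quant_state" (with the format, e.g. bitsandbytes__nf4, embedded in the key name), or an unpacked dict carrying 'quant_type' and the individual components. A minimal, self-contained sketch of that check follows; the key names and dummy tensors are illustrative, not from this commit.

import torch

# Illustrative dicts; these key names are examples only.
packed = {'weight.quant_state.bitsandbytes__nf4': torch.empty(8, dtype=torch.uint8)}
unpacked = {'quant_type': 'nf4', 'absmax': torch.ones(4), 'blocksize': 64}

for qs_dict in (packed, unpacked):
    # Same detection logic as the added assert in from_dict:
    qs_key = [k for k, v in qs_dict.items()
              if "quant_state" in k and isinstance(v, torch.Tensor)]
    assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict
    print('packed' if qs_key else 'unpacked')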

@@ -630,20 +632,20 @@ def as_dict(self, packed=False):
         param: packed -- returns dict[str, torch.Tensor] for state_dict
         """
         qs_dict = {
+            'quant_type': self.quant_type,
             'absmax': self.absmax,
+            'blocksize': self.blocksize,
             'code': self.code,
-            'shape': tuple(self.shape),
             'dtype': str(self.dtype).strip('torch.'),
-            'blocksize': self.blocksize,
-            'quant_type': self.quant_type,
+            'shape': tuple(self.shape) if self.nested else None,
         }
         if self.nested:
             qs_dict.update({
                 'nested_absmax': self.state2.absmax,
-                'nested_code': self.state2.code,
-                'nested_offset': self.offset.item(),
                 'nested_blocksize': self.state2.blocksize,
+                'nested_code': self.state2.code,
                 'nested_dtype': str(self.state2.dtype).strip('torch.'),
+                'nested_offset': self.offset.item(),
             })
         if not packed:
             return qs_dict
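A side note on the serialization shown above: as_dict stores the dtype as a string with the 'torch.' prefix stripped, and from_dict restores it via getattr. A small sketch of that round trip, as an observation on the code as written, not part of this commit:

import torch

# str(torch.float16) is 'torch.float16'; .strip('torch.') removes
# leading/trailing characters from the set {'t','o','r','c','h','.'},
# which yields 'float16' here, though strip() is character-based
# rather than a true prefix removal.
s = str(torch.float16).strip('torch.')
assert s == 'float16'
# from_dict recovers the dtype object with getattr:
assert getattr(torch, s) is torch.float16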

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
     @classmethod
     def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs):
         if data is None:
-            data = quantized_stats.pop('weight')
+            weight_key = [k for k in quantized_stats if k.endswith(".weight")][0]
+            data = quantized_stats.pop(weight_key)
         self = torch.Tensor._make_subclass(cls, data.to(device))
         self.requires_grad = requires_grad
         self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
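The modules.py change handles weights that arrive under their fully qualified state_dict names rather than a bare 'weight' key. A self-contained sketch of the new lookup; the module path below is illustrative, not from this commit.

# Hypothetical per-module slice of a checkpoint's state_dict:
quantized_stats = {
    'transformer.h.0.mlp.fc_in.weight': '<packed 4-bit tensor>',
    'quant_type': 'nf4',
}
# New lookup from the commit: take the first key ending in ".weight".
weight_key = [k for k in quantized_stats if k.endswith(".weight")][0]
data = quantized_stats.pop(weight_key)
assert weight_key == 'transformer.h.0.mlp.fc_in.weight'
# The remaining entries are what gets passed to QuantState.from_dict.
assert list(quantized_stats) == ['quant_type']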
