
Commit 1d541b5

[WiP] rework of Q_state save format

1 parent 48b3e77

File tree: 2 files changed (+41, -38 lines)


bitsandbytes/functional.py

Lines changed: 32 additions & 11 deletions
@@ -585,33 +585,54 @@ def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.device):
         unpacks dict of tensors into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
         """
-        tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
 
-        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
-
+        quant_state_dict = {k.split('.')[-1]: v for k, v in quant_state_dict.items()}
+        if 'quant_state_dict' in quant_state_dict:
+            quant_state_dict |= quant_state_dict.pop('quant_state_dict')
+
         if 'nested_absmax' in quant_state_dict:
-            offset = quant_state_dict['nested_offset']
+            offset = torch.tensor(float(quant_state_dict['nested_offset'])).to(device)
             state2 = cls(
                 absmax=quant_state_dict['nested_absmax'].to(device),
                 code=quant_state_dict['nested_code'].to(device),
-                blocksize=quant_state_dict['nested_blocksize'].item(),
-                dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
+                blocksize=int(quant_state_dict['nested_blocksize']),
+                dtype=getattr(torch, quant_state_dict['nested_dtype']),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
-            absmax=quant_state_dict['absmax'].to(device),
-            shape=torch.Size(quant_state_dict['shape']),
-            dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
-            blocksize=quant_state_dict['blocksize'].item(),
+            absmax=quant_state_dict['absmax'].to(device),
+            shape=torch.Size(map(int, quant_state_dict['shape'].split(','))),
+            dtype=getattr(torch, quant_state_dict['dtype']),
+            blocksize=int(quant_state_dict['blocksize']),
             offset=offset,
             state2=state2,
-            quant_type=tensor2str(quant_state_dict['quant_type']),
+            quant_type=quant_state_dict['quant_type'],
             code=quant_state_dict['code'].to(device),
         )
         return quant_state
 
+    def as_dict(self):
+        """dict of tensors and strings to use in serialization via _save_to_state_dict()"""
+        qs_dict = {
+            'absmax': self.absmax,
+            'code': self.code,
+            'shape': ','.join(map(str, self.shape)),
+            'dtype': str(self.dtype).replace('torch.', ''),
+            'blocksize': str(self.blocksize),
+            'quant_type': self.quant_type,
+        }
+        if self.nested:
+            qs_dict.update({
+                'nested_absmax': self.state2.absmax,
+                'nested_code': self.state2.code,
+                'nested_offset': f"{self.offset.item()}",
+                'nested_blocksize': str(self.state2.blocksize),
+                'nested_dtype': str(self.state2.dtype).replace('torch.', ''),
+            })
+        return qs_dict
+
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
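For illustration, a minimal round-trip sketch of the reworked format. It assumes the class being patched here is QuantState in bitsandbytes/functional.py with the as_dict()/from_dict() pair shown above; the tensor values are placeholders, not a real 4-bit quantized weight.

import torch
from bitsandbytes.functional import QuantState  # assumed import path for the patched class

# Placeholder statistics: absmax/code stand in for real quantization state.
qs = QuantState(
    absmax=torch.rand(64),
    shape=torch.Size([128, 32]),
    code=torch.rand(256),
    blocksize=64,
    quant_type='nf4',
    dtype=torch.float16,
)

qs_dict = qs.as_dict()  # tensors stay tensors; shape/dtype/blocksize become strings
# e.g. qs_dict['shape'] == '128,32' and qs_dict['dtype'] == 'float16'

restored = QuantState.from_dict(qs_dict, device=torch.device('cpu'))
assert restored.shape == qs.shape and restored.dtype == qs.dtype
assert restored.blocksize == qs.blocksize and restored.quant_type == qs.quant_type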

bitsandbytes/nn/modules.py

Lines changed: 9 additions & 27 deletions
@@ -224,36 +224,18 @@ def set_compute_type(self, x):
                 warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
                 warnings.filterwarnings('ignore', message='.*inference or training')
 
-
-    def _update_buffers(self):
-
-        def string_to_tensor(s):
-            """stores string as ints for serialization. assumes codes fit int16"""
-            return torch.tensor([ord(x) for x in s], dtype=torch.int16)
-
-        if getattr(self.weight, 'quant_state', None) is not None:
-            weight_quant_state = self.weight.quant_state
-            self.register_buffer('absmax', weight_quant_state.absmax)
-            self.register_buffer('shape', torch.tensor(weight_quant_state.shape))
-            self.register_buffer('dtype', string_to_tensor(str(weight_quant_state.dtype).strip('torch')))
-            self.register_buffer('blocksize', torch.tensor(weight_quant_state.blocksize))
-            self.register_buffer('quant_type', string_to_tensor(weight_quant_state.quant_type))
-            self.register_buffer('code', weight_quant_state.code)
-
-            if weight_quant_state.nested:
-                self.register_buffer('nested_offset', weight_quant_state.offset)
-                self.register_buffer('nested_absmax', weight_quant_state.state2.absmax)
-                self.register_buffer('nested_code', weight_quant_state.state2.code)
-                self.register_buffer('nested_blocksize', torch.tensor(weight_quant_state.state2.blocksize))
-                self.register_buffer('nested_dtype', string_to_tensor(str(weight_quant_state.state2.dtype).strip('torch')))
-
-
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
-        fill state_dict with components of nf4
-        TODO: test with other 4-bit Q-types
+        besides weight and bias,
+        fill state_dict with components of quant_state
         """
-        self._update_buffers() # link the quant_state items with _buffers
+        if getattr(self.weight, "quant_state", None) is not None:
+            quant_state_dict = self.weight.quant_state.as_dict()
+            tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)]
+            for k in tensor_keys:
+                destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach()
+            destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict
+            destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"]
         super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
 
     def forward(self, x: torch.Tensor):
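As a usage sketch (not part of the commit): what the new save path yields for a quantized layer. It assumes a CUDA device is available, since the 4-bit weight is quantized when the module is moved to the GPU.

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(64, 32, quant_type='nf4').cuda()  # .cuda() triggers 4-bit quantization
sd = layer.state_dict()

# Tensor components are stored under individual "weight.<key>" entries:
#   sd['weight.absmax'], sd['weight.code'], plus 'weight.nested_*' entries
#   when compress_statistics (double quantization) is enabled.
# String components travel together as one plain dict, plus a method tag:
#   sd['weight.quant_state_dict']    -> {'shape': ..., 'dtype': ..., 'blocksize': ..., ...}
#   sd['weight.quantization_method'] -> 'bitsandbytes.nf4'
print(sorted(sd.keys()))

On load, from_dict() in functional.py strips the "weight." prefixes via k.split('.')[-1] and folds the nested 'quant_state_dict' entry back into the flat dict before rebuilding the QuantState.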
