Skip to content

Commit 6cf0f05

Browse files
committed
rework of non-tensor qs items storage
1 parent 6a934d4 commit 6cf0f05

File tree

3 files changed

+85
-35
lines changed

3 files changed

+85
-35
lines changed

bitsandbytes/functional.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
import numpy as np
1414

1515
from functools import reduce # Required in Python 3
16-
from typing import Tuple
16+
from typing import Tuple, Any
1717
from torch import Tensor
18+
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
1819

1920
from .cextension import COMPILED_WITH_CUDA, lib
2021

@@ -580,58 +581,77 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
580581
self.nested = state2 is not None
581582

582583
@classmethod
def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
    """
    Build a QuantState from a dict of its components.

    Tensors are moved to `device`; non-tensor items (shape, dtype name,
    blocksize, quant_type, nested_* values) are converted where necessary.

    qs_dict may contain one item that holds the non-tensor components packed
    into a tensor, under a key like
    `...weight.quant_state.bitsandbytes__[nf4/fp4]`.
    That item is detected by its key, unpacked with unpack_tensor_to_dict()
    and merged with the remaining items.

    The dict passed in by the caller is not modified.

    Raises:
        ValueError: if more than one packed item is found, or if the packed
            item's key names neither `bitsandbytes__nf4` nor `bitsandbytes__fp4`.
    """
    # Work on a shallow copy so that pop()/update() below do not leak into the caller's dict.
    qs_dict = dict(qs_dict)

    # unpacking tensor with non-tensor components
    packed_keys = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    if len(packed_keys) > 1:
        # Ambiguous input: refuse to guess which packed item is the right one.
        raise ValueError(f"expected at most one packed quant_state item, found {packed_keys}")
    if packed_keys:
        qs_key = packed_keys[0]
        if 'bitsandbytes__nf4' not in qs_key and 'bitsandbytes__fp4' not in qs_key:
            raise ValueError(f"invalid qs_key value {qs_key}")
        qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))

    qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes

    if 'nested_absmax' in qs_dict:
        # Nested statistics are present: rebuild the inner state and its offset.
        offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
        state2 = cls(
            absmax=qs_dict['nested_absmax'].to(device),
            code=qs_dict['nested_code'].to(device),
            blocksize=qs_dict['nested_blocksize'],
            dtype=getattr(torch, qs_dict['nested_dtype']),
        )
    else:
        offset, state2 = None, None

    quant_state = cls(
        absmax=qs_dict['absmax'].to(device),
        shape=torch.Size(qs_dict['shape']),
        dtype=getattr(torch, qs_dict['dtype']),
        blocksize=qs_dict['blocksize'],
        offset=offset,
        state2=state2,
        quant_type=qs_dict['quant_type'],
        code=qs_dict['code'].to(device),
    )
    return quant_state
615626

616-
def as_dict(self, packed=False):
    """
    Return the components of this quant state as a dict, for use in
    serialization via _save_to_state_dict().

    param: packed -- if False (default), return the dict of tensors and plain
        Python values as is. If True, return dict[str, torch.Tensor]: tensor
        items are kept, and all non-tensor items are packed into one tensor
        by pack_dict_to_tensor() and stored under the key
        "quant_state.bitsandbytes__<quant_type>".
    """
    qs_dict = {
        'absmax': self.absmax,
        'code': self.code,
        'shape': tuple(self.shape),
        # removeprefix, not strip: strip('torch.') removes any of the characters
        # t/o/r/c/h/. from both ends and would mangle e.g. 'torch.complex64'.
        'dtype': str(self.dtype).removeprefix('torch.'),
        'blocksize': self.blocksize,
        'quant_type': self.quant_type,
    }
    if self.nested:
        qs_dict.update({
            'nested_absmax': self.state2.absmax,
            'nested_code': self.state2.code,
            'nested_offset': self.offset.item(),
            'nested_blocksize': self.state2.blocksize,
            'nested_dtype': str(self.state2.dtype).removeprefix('torch.'),
        })
    if not packed:
        return qs_dict

    # Split by value type: tensors go into the result directly,
    # everything else is packed into a single tensor.
    qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
    non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
    qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
    return qs_packed_dict
635655

636656
def to(self, device):
637657
# make sure the quantization state is on the right device

bitsandbytes/nn/modules.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, devi
159159
data = quantized_stats.pop('weight')
160160
self = torch.Tensor._make_subclass(cls, data.to(device))
161161
self.requires_grad = requires_grad
162-
self.quant_state = QuantState.from_dict(quant_state_dict=quantized_stats, device=device)
162+
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
163163
self.blocksize = self.quant_state.blocksize
164164
self.compress_statistics = self.quant_state.nested
165165
self.quant_type = self.quant_state.quant_type
@@ -226,18 +226,14 @@ def set_compute_type(self, x):
226226

227227
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """
    Let the parent class save weight and bias first,
    then add the components of the weight's quant_state to the state_dict.
    """
    super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

    # Nothing more to store when the weight carries no quant_state.
    if getattr(self.weight, "quant_state", None) is None:
        return

    packed_items = self.weight.quant_state.as_dict(packed=True)
    for name, value in packed_items.items():
        # Entries are detached unless the caller asked to keep the variables.
        destination[prefix + "weight." + name] = value if keep_vars else value.detach()
241237

242238
def forward(self, x: torch.Tensor):
243239
# weights are cast automatically as Int8Params, but the bias has to be cast manually

bitsandbytes/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import shlex
23
import subprocess
34
import torch
@@ -158,3 +159,36 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei
158159
if func is not None: func(module)
159160
return model
160161

162+
163+
def pack_dict_to_tensor(source_dict: dict) -> torch.Tensor:
    """
    Pack a dictionary into a torch tensor for storing quant_state items in state_dict.

    The dictionary is serialized to JSON, encoded as UTF-8, and the resulting
    bytes become the elements of a 1-D uint8 tensor.

    Parameters:
    - source_dict: The dictionary to be packed. Must be JSON-serializable.

    Returns:
    A torch tensor (dtype torch.uint8) containing the packed data.
    """
    json_bytes = json.dumps(source_dict).encode('utf-8')
    # Build the tensor straight from the byte buffer instead of going through a
    # list with one Python int per byte. A bytearray is used because
    # torch.frombuffer expects a writable buffer.
    return torch.frombuffer(bytearray(json_bytes), dtype=torch.uint8)
178+
179+
180+
def unpack_tensor_to_dict(tensor_data: torch.Tensor) -> dict:
    """
    Unpack a torch tensor into a Python dictionary.

    The tensor's raw bytes are decoded as UTF-8 and parsed as JSON.

    Parameters:
    - tensor_data: The torch tensor containing the packed data.

    Returns:
    A Python dictionary containing the unpacked data.
    """
    # Tensor.numpy() only works for CPU tensors, so bring the data to the CPU
    # first; this is a no-op when the tensor already lives there.
    json_bytes = bytes(tensor_data.detach().cpu().numpy())
    json_str = json_bytes.decode('utf-8')
    return json.loads(json_str)

0 commit comments

Comments
 (0)