Commit c6d0a84

cleanup 0
1 parent 7d1c9cf commit c6d0a84

File tree: 2 files changed (+18, −15 lines)

bitsandbytes/functional.py

Lines changed: 8 additions & 7 deletions
@@ -567,14 +567,14 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
 
     return out
 
+
 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""
     valid_quant_types = ('fp4', 'nf4')
     valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
     valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
                      'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
 
-
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
@@ -585,7 +585,7 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
         self.offset = offset
         self.state2 = state2
         self.nested = state2 is not None
-
+
     def __get_item__(self, idx):
         """
         ensures compatibility with older quant state scheme with nested lists.
@@ -598,15 +598,15 @@ def __get_item__(self, idx):
         else:
             list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
         return list_repr[idx]
-
+
     @classmethod
     def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
         """
         unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
 
         qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
-
+
         item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
 
@@ -615,8 +615,8 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
         elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
-
+            raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
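
For context, the validation above enforces the key convention defined by valid_qs_type_keys: a serialized state is either "packed" (exactly one quant_state.bitsandbytes__{fp4,nf4} entry holding the minor/non-tensor items) or "unpacked" (a plain 'quant_type' entry alongside the individual components). A minimal sketch of that convention, with made-up placeholder values:

# Key convention checked by from_dict (sketch; values are placeholders,
# only the key names follow the diff above).
valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]

packed = {"absmax": "<tensor>", "quant_state.bitsandbytes__nf4": "<packed tensor>"}
unpacked = {"absmax": "<tensor>", "quant_type": "nf4", "blocksize": 64, "dtype": "float16"}

# Exactly one packed key is allowed; otherwise from_dict raises ValueError.
qs_keys = [k for k in packed if k in valid_qs_type_keys]
assert len(qs_keys) == 1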
@@ -673,7 +673,7 @@ def as_dict(self, packed=False):
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
         qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
         return qs_packed_dict
-
+
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
@@ -682,6 +682,7 @@ def to(self, device):
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)
 
+
 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
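
To see QuantState in action, here is a hedged sketch of the quantize/serialize/restore round-trip; it assumes a CUDA device and this branch's quantize_4bit/dequantize_4bit returning and accepting a QuantState (tensor sizes and names are illustrative):

# Round-trip sketch for QuantState (assumes a CUDA device and this
# branch's API; sizes and names are illustrative).
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Quantize to NF4; returns the packed 4-bit tensor and its QuantState.
qA, qstate = F.quantize_4bit(A, blocksize=64, quant_type="nf4")

# Serialize the state for a state_dict; packed=True folds the non-tensor
# items into one "quant_state.bitsandbytes__nf4" tensor entry.
qs_dict = qstate.as_dict(packed=True)

# Rebuild the state on load and dequantize with it.
restored = F.QuantState.from_dict(qs_dict=qs_dict, device=torch.device("cuda"))
A_hat = F.dequantize_4bit(qA, quant_state=restored)
print((A - A_hat).abs().max())  # small quantization error, not exact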

bitsandbytes/nn/modules.py

Lines changed: 10 additions & 8 deletions
@@ -139,6 +139,7 @@ def forward(self, input: Tensor) -> Tensor:
 
         return emb
 
+
 class Params4bit(torch.nn.Parameter):
 
     def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
@@ -152,27 +153,27 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         self.quant_state = quant_state
         self.data = data
         return self
-
+
     @classmethod
     def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
         data = state_dict.pop(prefix.rstrip('.'))
-
+
         # extracting components for QuantState from state_dict
         qs_dict = {}
         for k, v in state_dict.items():
             if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
                 qs_dict[k] = v
         state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
         qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
+
         if data.device.type != "cuda":
             raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
 
         cls.requires_grad = requires_grad
         cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-        cls.blocksize = cls.quant_state.blocksize
-        cls.compress_statistics = cls.quant_state.nested
-        cls.quant_type = cls.quant_state.quant_type
+        cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
+        cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
+        cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
 
         self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
         return self, state_dict
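
How from_state_dict is meant to be driven, as a hedged sketch: the checkpoint file and the "layer.weight." prefix are hypothetical, and it assumes the checkpoint stores the packed weight together with the QuantState entries listed in valid_qs_keys:

# Hypothetical load path for a serialized 4-bit weight (sketch; the file
# name and prefix are invented, and a CUDA device is required by the
# check in from_state_dict above).
import torch
import bitsandbytes as bnb

sd = torch.load("model_4bit.pt")  # hypothetical checkpoint
sd = {k: (v.cuda() if torch.is_tensor(v) else v) for k, v in sd.items()}

# Pops "layer.weight" plus every "layer.weight.<quant-state-key>" entry
# and returns the rebuilt Params4bit alongside the remaining state_dict.
weight, sd = bnb.nn.Params4bit.from_state_dict(
    sd, prefix="layer.weight.", requires_grad=False)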
@@ -207,14 +208,15 @@ def to(self, *args, **kwargs):
             self.quant_state.to(device)
 
             new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
-                                   requires_grad=self.requires_grad, quant_state=self.quant_state, 
+                                   requires_grad=self.requires_grad, quant_state=self.quant_state,
                                    blocksize=self.blocksize, compress_statistics=self.compress_statistics,
                                    quant_type=self.quant_type)
 
             return new_param
 
+
 class Linear4bit(nn.Linear):
-
+
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
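
And the consumer-side picture, a minimal sketch of Linear4bit: the weight quantizes on the move to CUDA, and the forward pass then runs against the 4-bit weight (sizes are arbitrary):

# Minimal Linear4bit usage (sketch; sizes are arbitrary, CUDA assumed).
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(128, 256, bias=False,
                          compute_dtype=torch.float16,
                          compress_statistics=True, quant_type="nf4")
layer = layer.cuda()  # the Params4bit weight is quantized on this move

x = torch.randn(8, 128, device="cuda", dtype=torch.float16)
y = layer(x)          # matmul against the 4-bit weight, fp16 compute
print(y.shape)        # torch.Size([8, 256])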
