Skip to content

Commit 48b3e77

Browse files
committed
some renaming
1 parent 5bcc1dd commit 48b3e77

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

bitsandbytes/functional.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
567567
return out
568568

569569
class QuantState:
570+
"""container for quantizationstate components to work with Params4bit and similar clases"""
570571
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
571572
self.absmax = absmax
572573
self.shape = shape
@@ -579,32 +580,35 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
579580
self.nested = state2 is not None
580581

581582
@classmethod
582-
def from_kwargs(cls, kwargs, device):
583-
583+
def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.device) -> 'QuantState':
584+
"""
585+
unpacks a dict of tensors into a QuantState
586+
where necessary, converts them into strings, torch.dtype, ints, etc.
587+
"""
584588
tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
585589

586-
kwargs = {k.split('.')[-1] :v for k, v in kwargs.items()}
590+
quant_state_dict = {k.split('.')[-1] :v for k, v in quant_state_dict.items()}
587591

588-
if 'nested_absmax' in kwargs:
589-
offset = kwargs['nested_offset']
592+
if 'nested_absmax' in quant_state_dict:
593+
offset = quant_state_dict['nested_offset']
590594
state2 = cls(
591-
absmax=kwargs['nested_absmax'].to(device),
592-
code=kwargs['nested_code'].to(device),
593-
blocksize=kwargs['nested_blocksize'].item(),
594-
dtype=getattr(torch, tensor2str(kwargs['nested_dtype'])),
595+
absmax=quant_state_dict['nested_absmax'].to(device),
596+
code=quant_state_dict['nested_code'].to(device),
597+
blocksize=quant_state_dict['nested_blocksize'].item(),
598+
dtype=getattr(torch, tensor2str(quant_state_dict['nested_dtype'])),
595599
)
596600
else:
597601
offset, state2 = None, None
598602

599603
quant_state = cls(
600-
absmax=kwargs['absmax'].to(device),
601-
shape=torch.Size(kwargs['shape']),
602-
dtype=getattr(torch, tensor2str(kwargs['dtype'])),
603-
blocksize=kwargs['blocksize'].item(),
604+
absmax=quant_state_dict['absmax'].to(device),
605+
shape=torch.Size(quant_state_dict['shape']),
606+
dtype=getattr(torch, tensor2str(quant_state_dict['dtype'])),
607+
blocksize=quant_state_dict['blocksize'].item(),
604608
offset=offset,
605609
state2=state2,
606-
quant_type=tensor2str(kwargs['quant_type']),
607-
code=kwargs['code'].to(device),
610+
quant_type=tensor2str(quant_state_dict['quant_type']),
611+
code=quant_state_dict['code'].to(device),
608612
)
609613
return quant_state
610614

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, devi
159159
data = quantized_stats.pop('weight')
160160
self = torch.Tensor._make_subclass(cls, data.to(device))
161161
self.requires_grad = requires_grad
162-
self.quant_state = QuantState.from_kwargs(kwargs=quantized_stats, device=device)
162+
self.quant_state = QuantState.from_dict(quant_state_dict=quantized_stats, device=device)
163163
self.blocksize = self.quant_state.blocksize
164164
self.compress_statistics = self.quant_state.nested
165165
self.quant_type = self.quant_state.quant_type

0 commit comments

Comments
 (0)