Skip to content

Commit 7d1c9cf

Browse files
committed
extra feats in constructor Params4bit
1 parent f1ef74f commit 7d1c9cf

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

bitsandbytes/nn/modules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,11 @@ def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
168168
if data.device.type != "cuda":
169169
raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
170170

171-
cls.requires_grad = requires_grad,
171+
cls.requires_grad = requires_grad
172172
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
173+
cls.blocksize = cls.quant_state.blocksize
174+
cls.compress_statistics = cls.quant_state.nested
175+
cls.quant_type = cls.quant_state.quant_type
173176

174177
self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
175178
return self, state_dict

0 commit comments

Comments
 (0)