Commit 7fed393
Fix restoration of quant_storage for CPU offloading (#1279)

* Fix restoration of quant_storage for CPU offloading
* Clarify comment on default quant_storage in Params4bit.from_prequantized()
* fix to make quant_storage dynamic based on serialized dtype
* delete obsolete comment

Co-authored-by: Titus von Koeller <[email protected]>

1 parent e3ae243 · commit 7fed393
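In short: a Params4bit moved off the GPU and back previously lost its non-default quant_storage dtype. Below is a minimal sketch of the affected round trip, not taken from the commit; it assumes a CUDA device and the public Linear4bit API, with the layer size and quant settings chosen only for illustration.

    import torch
    import bitsandbytes as bnb

    # A 4-bit layer whose packed weights use bfloat16 storage (the
    # FSDP-friendly choice) instead of the default torch.uint8.
    layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4", quant_storage=torch.bfloat16)
    layer = layer.to("cuda")  # quantizes the weight on the first move to CUDA

    # CPU offload and restore. Before this commit, Params4bit.to() rebuilt
    # the parameter without forwarding quant_storage, so the round trip
    # silently fell back to uint8 storage.
    layer = layer.to("cpu")
    layer = layer.to("cuda")
    assert layer.weight.quant_storage == torch.bfloat16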

File tree

1 file changed: +6 −2 lines changed

bitsandbytes/nn/modules.py

Lines changed: 6 additions & 2 deletions
@@ -282,10 +282,13 @@ def from_prequantized(
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
         self.bnb_quantized = True
+
+        self.quant_storage = data.dtype
+
         return self
 
     def _quantize(self, device):
-        w = self.data.contiguous().cuda(device)
+        w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
             w,
             blocksize=self.blocksize,
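With this change, from_prequantized() infers the storage dtype from the tensor it is handed rather than assuming the uint8 default. A hedged sketch of that behavior follows; packed_weight and stats are hypothetical placeholders for a checkpoint's packed weight tensor and its serialized quant-state entries.

    from bitsandbytes.nn import Params4bit

    # packed_weight: the quantized weight tensor as saved in a checkpoint
    # (e.g. dtype torch.bfloat16); stats: its serialized QuantState entries.
    p = Params4bit.from_prequantized(
        data=packed_weight,
        quantized_stats=stats,
        device="cuda",
    )
    # quant_storage now mirrors whatever dtype the weight was serialized with.
    assert p.quant_storage == packed_weight.dtype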
@@ -333,6 +336,7 @@ def to(self, *args, **kwargs):
             blocksize=self.blocksize,
             compress_statistics=self.compress_statistics,
             quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
         )
 
         return new_param
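The effect at the parameter level, sketched under the same assumptions as above: moving an already-quantized Params4bit between devices rebuilds it with its quantization settings, and quant_storage is now carried along.

    # `layer` as in the earlier sketch, already quantized on CUDA.
    p_cpu = layer.weight.to("cpu")
    assert isinstance(p_cpu, bnb.nn.Params4bit)
    assert p_cpu.quant_storage == layer.weight.quant_storage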
@@ -450,7 +454,7 @@ def forward(self, x: torch.Tensor):
                 # since we registered the module, we can recover the state here
                 assert self.weight.shape[1] == 1
                 if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
                 self.weight.quant_state = self.quant_state
             else:
                 print(
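In the FSDP recovery path, the tensor that comes back already holds packed 4-bit data, so the rebuilt Params4bit is now flagged bnb_quantized=True; without the flag, a later move to CUDA would run quantization a second time over bytes that are already quantized. A rough sketch under the same assumptions (names illustrative):

    # Recover a Params4bit from a flat tensor of already-packed 4-bit data.
    flat = layer.weight.data.reshape(-1, 1)
    w = bnb.nn.Params4bit(flat, quant_storage=flat.dtype, bnb_quantized=True)
    w.quant_state = layer.weight.quant_state  # reattach the saved state
    w = w.to("cuda")  # the flag prevents a second quantization pass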
