
Commit 87f88af

Enable loading prequantized weights with bf16/fp16/fp32 quant_storage type for FSDP
Parent: 2621e1a

1 file changed: +5 -0 lines changed

bitsandbytes/nn/modules.py

Lines changed: 5 additions & 0 deletions
@@ -273,6 +273,7 @@ def from_prequantized(
         quantized_stats: Dict[str, Any],
         requires_grad: bool = False,
         device="cuda",
+        module: Optional["Linear4bit"] = None,
         **kwargs,
     ) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -284,6 +285,10 @@ def from_prequantized(
         self.bnb_quantized = True
 
         self.quant_storage = data.dtype
+        self.module = module
+
+        if self.module is not None:
+            self.module.quant_state = self.quant_state
 
         return self
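Below is a minimal, hypothetical sketch of how the new `module` argument might be used when rebuilding a 4-bit weight from a serialized checkpoint. `Linear4bit` and `Params4bit.from_prequantized` are taken from the diff above; the `load_prequantized` helper and the assumed state-dict key layout are illustrative only, not the actual bitsandbytes loading path.

```python
# Hypothetical usage sketch -- not part of bitsandbytes itself.
# Assumes a state dict where "<prefix>weight" holds the packed 4-bit data
# (possibly stored as bf16/fp16/fp32 rather than uint8) and the remaining
# "<prefix>weight.*" entries hold the serialized quantization statistics.
from typing import Dict

import torch
from bitsandbytes.nn.modules import Linear4bit, Params4bit


def load_prequantized(layer: Linear4bit, state_dict: Dict[str, torch.Tensor], prefix: str = "") -> Linear4bit:
    weight_key = prefix + "weight"
    data = state_dict[weight_key]
    # Collect the serialized quant_state entries stored under the weight key.
    quantized_stats = {
        k[len(weight_key) + 1:]: v
        for k, v in state_dict.items()
        if k.startswith(weight_key + ".")
    }
    layer.weight = Params4bit.from_prequantized(
        data=data,
        quantized_stats=quantized_stats,
        device="cuda",
        # New in this commit: passing the owning module also stores the
        # quant_state on the Linear4bit layer itself, which (per the commit
        # message) is what allows prequantized weights with bf16/fp16/fp32
        # quant_storage to be loaded for FSDP.
        module=layer,
    )
    return layer
```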