Skip to content

Commit 17273b4

Browse files
author
Ali Alshaarawy
committed
add quant_storage param
1 parent 7365d1e commit 17273b4

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,15 @@ def _replace_param(
184184
if param.device.type == "meta":
185185
if isinstance(param, bnb.nn.Params4bit):
186186
return bnb.nn.Params4bit(
187-
data,
188-
requires_grad=data.requires_grad,
189-
quant_state=quant_state,
190-
compress_statistics=param.compress_statistics,
191-
quant_type=param.quant_type,
187+
data = data,
188+
requires_grad = data.requires_grad,
189+
quant_state = quant_state,
190+
blocksize = param.blocksize,
191+
compress_statistics = param.compress_statistics,
192+
quant_type = param.quant_type,
193+
quant_storage = param.quant_storage,
194+
module = param.module,
195+
bnb_quantized = param.bnb_quantized
192196
)
193197
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
194198
param.data = data
@@ -338,6 +342,7 @@ def quantize(
338342
blocksize=params4bit.blocksize,
339343
compress_statistics=params4bit.compress_statistics,
340344
quant_type=params4bit.quant_type,
345+
quant_storage=params4bit.quant_storage,
341346
)
342347
return _replace_param(params4bit, w_4bit, quant_state)
343348

0 commit comments

Comments
 (0)