@@ -184,11 +184,15 @@ def _replace_param(
184
184
if param .device .type == "meta" :
185
185
if isinstance (param , bnb .nn .Params4bit ):
186
186
return bnb .nn .Params4bit (
187
- data ,
188
- requires_grad = data .requires_grad ,
189
- quant_state = quant_state ,
190
- compress_statistics = param .compress_statistics ,
191
- quant_type = param .quant_type ,
187
+ data = data ,
188
+ requires_grad = data .requires_grad ,
189
+ quant_state = quant_state ,
190
+ blocksize = param .blocksize ,
191
+ compress_statistics = param .compress_statistics ,
192
+ quant_type = param .quant_type ,
193
+ quant_storage = param .quant_storage ,
194
+ module = param .module ,
195
+ bnb_quantized = param .bnb_quantized
192
196
)
193
197
return torch .nn .Parameter (data , requires_grad = data .requires_grad )
194
198
param .data = data
@@ -338,6 +342,7 @@ def quantize(
338
342
blocksize = params4bit .blocksize ,
339
343
compress_statistics = params4bit .compress_statistics ,
340
344
quant_type = params4bit .quant_type ,
345
+ quant_storage = params4bit .quant_storage ,
341
346
)
342
347
return _replace_param (params4bit , w_4bit , quant_state )
343
348
0 commit comments