@@ -36,7 +36,6 @@ def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
3636 Returns:
3737 `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
3838 """
39- # print(f"Dequantizing parameter '{self.p_name}'")
4039 return F .dequantize_4bit (quantized_param , self .quant_state )
4140
4241
@@ -54,7 +53,7 @@ def replace_parameter_4bit_prequantized(
5453 quant_state = F .QuantState .from_dict (qs_dict , device = device )
5554
5655 # Apply a parametrization to the module to handle dequantization.
57- P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state , p_name = param_name ), unsafe = True )
56+ P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state ), unsafe = True )
5857
5958 # Next, register state dict hook for saving.
6059 module .register_state_dict_post_hook (
@@ -126,7 +125,7 @@ def replace_parameter_4bit(
126125 del original_param
127126
128127 # Apply a parametrization to the module to handle dequantization.
129- P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state , p_name = param_name ), unsafe = True )
128+ P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state ), unsafe = True )
130129
131130 # Next, register state dict hook for saving.
132131 module .register_state_dict_post_hook (
0 commit comments