@@ -21,9 +21,10 @@ class Bnb4bitParametrization(nn.Module):
2121 The quantization state containing the necessary information for dequantization.
2222 """
2323
def __init__(self, quant_state: F.QuantState, p_name: str = "unknown") -> None:
    """
    Store the quantization state and a label for the parametrized parameter.

    Args:
        quant_state (`F.QuantState`):
            The quantization state kept on the module and passed to `F.dequantize_4bit` in `forward`.
        p_name (`str`, *optional*, defaults to `"unknown"`):
            Name of the parameter this parametrization is attached to. It is only stored on the
            instance (useful for debugging); dequantization itself does not read it.
    """
    super().__init__()
    self.quant_state = quant_state
    self.p_name = p_name
2728
2829 def forward (self , quantized_param : torch .Tensor ) -> torch .Tensor :
2930 """
@@ -35,9 +36,35 @@ def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
3536 Returns:
3637 `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
3738 """
39+ # print(f"Dequantizing parameter '{self.p_name}'")
3840 return F .dequantize_4bit (quantized_param , self .quant_state )
3941
4042
def replace_parameter_4bit_prequantized(
    module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
) -> None:
    """
    Register 4-bit dequantization on an existing parameter using a serialized quantization state.

    The parameter's data is not quantized here; it is assumed to already be stored in quantized
    form (TODO confirm against callers). A `QuantState` is rebuilt from `qs_dict`, a
    `Bnb4bitParametrization` holding it is registered on the parameter, and a state dict post hook
    is registered on the module.

    Args:
        module (`torch.nn.Module`):
            The module that owns the parameter.
        param_name (`str`):
            The attribute name of the parameter on `module`.
        qs_dict (`dict[str, Any]`):
            The serialized quantization state, passed to `F.QuantState.from_dict`.
        device (`torch.device`):
            The device passed to `F.QuantState.from_dict`.

    Raises:
        AttributeError: If `module` has no attribute named `param_name`.
        TypeError: If the attribute is not an `nn.Parameter`.
    """
    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    original_param = getattr(module, param_name)

    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    quant_state = F.QuantState.from_dict(qs_dict, device=device)

    # Apply a parametrization to the module to handle dequantization.
    # `unsafe=True` is forwarded to `register_parametrization` exactly as in `replace_parameter_4bit`.
    P.register_parametrization(
        module,
        param_name,
        Bnb4bitParametrization(quant_state, p_name=param_name),
        unsafe=True,
    )

    # Next, register state dict hook for saving.
    module.register_state_dict_post_hook(
        partial(
            _parametrized_state_dict_post_hook,
            param_name=param_name,
        )
    )
66+
67+
4168def replace_parameter_4bit (
4269 module : nn .Module ,
4370 param_name : str ,
@@ -99,7 +126,7 @@ def replace_parameter_4bit(
99126 del original_param
100127
101128 # Apply a parametrization to the module to handle dequantization.
102- P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state ), unsafe = True )
129+ P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state , p_name = param_name ), unsafe = True )
103130
104131 # Next, register state dict hook for saving.
105132 module .register_state_dict_post_hook (
0 commit comments