33
33
from compressed_tensors .quantization .quant_scheme import QuantizationScheme
34
34
from compressed_tensors .quantization .utils import is_fp4 , is_kv_cache_quant_scheme
35
35
from compressed_tensors .utils import (
36
+ delete_offload_parameter ,
36
37
disable_hf_hook ,
37
38
get_execution_device ,
38
39
register_offload_parameter ,
@@ -61,10 +62,11 @@ def initialize_module_for_quantization(
61
62
force_zero_point : bool = True ,
62
63
):
63
64
"""
64
- attaches appropriate scales, zero points, and observers to a layer
65
- given its target quantization scheme
65
+ Attaches appropriate scales, zero points, and observers to a layer
66
+ given its target quantization scheme.
66
67
67
- apply to full model with `model.apply(initialize_module_for_quantization)`
68
+ Previously initialized scales and zero points will be removed from
69
+ module if they no longer apply to the scheme
68
70
69
71
:param module: module to set for calibration
70
72
:param scheme: scheme to use for quantization. if None is provided,
@@ -73,6 +75,8 @@ def initialize_module_for_quantization(
73
75
:param force_zero_point: whether to force initialization of a zero point for
74
76
symmetric quantization
75
77
"""
78
+ _clear_all_qparams (module )
79
+
76
80
# TODO: don't initialize parameters when running decompression
77
81
scheme = scheme or getattr (module , "quantization_scheme" , None )
78
82
if scheme is None :
@@ -134,6 +138,29 @@ def is_attention_module(module: Module):
134
138
)
135
139
136
140
141
def _clear_all_qparams(
    module: Module,
):
    """
    Remove any quantization parameters left on the module by an earlier
    initialization

    The candidate attribute names are the values of the `KEY` and `VALUE`
    members of `KVCacheScaleType`, followed by `{base_name}_{suffix}` for every
    combination of the input, weight, and output base names with the
    global_scale, scale, zero_point, and g_idx suffixes. Candidates which are
    not attributes of the module are skipped

    :param module: module whose quantization parameters are removed
    """
    base_names = ("input", "weight", "output")
    suffixes = ("global_scale", "scale", "zero_point", "g_idx")

    # kv cache names come first, then one name per (base_name, suffix) pair
    qparam_names = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value]
    qparam_names.extend(
        f"{base_name}_{suffix}" for base_name in base_names for suffix in suffixes
    )

    for name in qparam_names:
        # only names actually present on the module are deleted
        if not hasattr(module, name):
            continue
        delete_offload_parameter(module, name)
162
+
163
+
137
164
def _initialize_scale_zero_point (
138
165
module : Module ,
139
166
base_name : str ,
0 commit comments