47
47
__all__ = [
48
48
"initialize_module_for_quantization" ,
49
49
"is_attention_module" ,
50
+ "initialize_qparams" ,
50
51
]
51
52
52
53
@@ -100,7 +101,7 @@ def initialize_module_for_quantization(
100
101
return
101
102
102
103
if scheme .input_activations is not None :
103
- _initialize_scale_zero_point (
104
+ initialize_qparams (
104
105
module ,
105
106
"input" ,
106
107
scheme .input_activations ,
@@ -110,7 +111,7 @@ def initialize_module_for_quantization(
110
111
)
111
112
112
113
if scheme .weights is not None :
113
- _initialize_scale_zero_point (
114
+ initialize_qparams (
114
115
module ,
115
116
"weight" ,
116
117
scheme .weights ,
@@ -121,7 +122,7 @@ def initialize_module_for_quantization(
121
122
122
123
output_is_kv_cache = is_kv_cache_quant_scheme (scheme )
123
124
if scheme .output_activations is not None and not output_is_kv_cache :
124
- _initialize_scale_zero_point (
125
+ initialize_qparams (
125
126
module ,
126
127
"output" ,
127
128
scheme .output_activations ,
@@ -147,14 +148,29 @@ def is_attention_module(module: Module):
147
148
)
148
149
149
150
150
- def _initialize_scale_zero_point (
151
+ def initialize_qparams (
151
152
module : Module ,
152
153
base_name : str ,
153
154
quantization_args : QuantizationArgs ,
154
155
observed_shape : Tuple [int ],
155
156
observed_dtype : torch .dtype ,
156
157
force_zero_point : bool = True ,
157
158
):
159
+ """
160
+ Initialize quantization parameters for a given basename according to the passed
161
+ quantization args. The shape and dtype of the observed weight/activation must also
162
+ be provided.
163
+
164
+ Scales will always be initialized. Global scales are initialized depending on args.
165
+ Zero points will be initialized if not symmetric or if `force_zero_point` is True.
166
+
167
+ :param module: module to register qparams to
168
+ :param base_name: base name of qparams, for example "input", "weight", "k", "v"
169
+ :param quantization_args: arguments for quantization
170
+ :param observed_shape: last (right-most) known dimensions of the observed weight/act
171
+ :param observed_dtype: dtype of the observed weight/actt
172
+ :param force_zero_point: force the zero_point parameter to be initialized
173
+ """
158
174
strategy = quantization_args .strategy
159
175
dynamic = quantization_args .dynamic
160
176
actorder = quantization_args .actorder
0 commit comments