@@ -145,6 +145,88 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Set up the observer/fake-quant constructor for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for the 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Clone the base config from get_symmetric_quantization_config (forwarding
+    # weight_qmin/weight_qmax so they take effect) and swap in the 16-bit activation spec
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+        weight_qmin=weight_qmin,
+        weight_qmax=weight_qmax,
+    )
+    # Replace the activation quantization spec with the 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
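
For context (not part of the diff): a quick sanity sketch of what the returned config contains. The field names below are assumptions inferred from the positional QuantizationConfig constructor calls in the hunk above (input_activation, output_activation, weight, bias), not confirmed by this diff.

# Illustrative sketch only; assumes this module's imports and the
# QuantizationConfig field order used positionally in the diff.
config = get_symmetric_a16w8_quantization_config()

assert config.input_activation.dtype == torch.int16  # 16-bit activations
assert config.output_activation.quant_max == 32767   # torch.iinfo(torch.int16).max
assert config.weight.dtype == torch.int8             # 8-bit weights from the base config
assert config.bias is None                           # no bias spec is set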