@@ -145,6 +145,88 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Instead of reconstructing the whole config, clone the base config from
+    # get_symmetric_quantization_config and swap in the 16-bit activation spec
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+        weight_qmin=weight_qmin,
+        weight_qmax=weight_qmax,
+    )
+    # Replace activation quantization spec with 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
     a Node and returns whether the node should be annotated or not.
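A minimal usage sketch of the new function (not part of the commit; it assumes the QuantizationConfig dataclass exposes input_activation/weight fields matching the positional arguments above, and that the surrounding quantizer offers the usual set_global entry point):

    a16w8_config = get_symmetric_a16w8_quantization_config(is_per_channel=True)

    # Activations are observed and quantized as int16; weights keep the 8-bit
    # spec cloned from the base config.
    assert a16w8_config.input_activation.dtype == torch.int16
    assert a16w8_config.weight.dtype == torch.int8

    # The config would then typically be installed on a quantizer before
    # prepare_pt2e/convert_pt2e, e.g. quantizer.set_global(a16w8_config)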