 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -154,78 +160,120 @@ def get_supported_operators(cls) -> list[OperatorConfig]:
 
 
 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
+def act_qspec(is_qat: bool):
+    eps = 2**-12
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver, eps=eps
+        )
+        if is_qat
+        else HistogramObserver.with_args(eps=eps)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
 
-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
 
 # Is set by the *PatternQuantizer directly.
 bias_qspec = None
 
 
 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(ConvTranspose2dPattern(), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(MulTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SliceTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(MulTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
         # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
             pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -235,7 +283,9 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
             pt: False for q in self.quantizers for pt in q.pattern.partition_types()
         }
         self.cluster_quantizers = [
-            NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
+            NeutronAtenQuantizer(
+                ActivationsConcatClusterPattern(self, is_qat=is_qat), static_qconfig
+            )
         ]
 
     def transform_for_annotation(
@@ -288,7 +338,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
                 continue
 
             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                 self._mark_input_node_as_annotated(node)
 
     def validate(self, model: torch.fx.GraphModule) -> None:
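
Usage note: the change above threads a single is_qat flag from the NeutronQuantizer constructor through every pattern and quantization spec, so one quantizer definition covers both post-training quantization (observers that only record ranges) and quantization-aware training (fake-quantize modules trained with the model). Below is a minimal sketch of how that flag might be driven from the pt2e flow. The prepare/convert helpers are assumed to live in torchao.quantization.pt2e.quantize_pt2e (mirroring the torch.ao pt2e API), quantize_for_neutron is a hypothetical helper, and the construction of NeutronTargetSpec is elided, so treat this as an illustration rather than the backend's documented entry point.

# Sketch only: the prepare_pt2e/prepare_qat_pt2e/convert_pt2e import path is an
# assumption; NeutronQuantizer is the class defined in the file patched above.
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def quantize_for_neutron(model, example_inputs, target_spec, is_qat=False):
    quantizer = NeutronQuantizer(target_spec, is_qat=is_qat)
    graph_module = export(model, example_inputs).module()
    if is_qat:
        # QAT path: FusedMovingAvgObsFakeQuantize / FakeQuantize modules are
        # inserted, so quantization error is simulated while fine-tuning.
        prepared = prepare_qat_pt2e(graph_module, quantizer)
        # ... run the usual training loop on `prepared` here ...
    else:
        # PTQ path: HistogramObserver / MinMaxObserver only collect statistics.
        prepared = prepare_pt2e(graph_module, quantizer)
        prepared(*example_inputs)  # calibration pass(es)
    return convert_pt2e(prepared)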