@@ -184,14 +184,29 @@ def get_default_8bit_qnn_ptq_config(
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}

-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        qscheme=(
-            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
-        ),
-        ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to a zero_point of 128.
+        # If the uint8 quant_min/quant_max are provided, it uses 127 as the zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch will remove redundant observers based on attributes such as:
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing values like quant_min and quant_max helps observers compare
+        # and further reduces the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=torch.iinfo(torch.uint8).min,
+            quant_max=torch.iinfo(torch.uint8).max,
+            qscheme=torch.per_tensor_affine,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )

     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
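The zero_point behavior described in the comment above can be reproduced with the observer this config relies on. A minimal sketch (illustration only, not part of the diff), assuming a PyTorch 2.x build where observers accept the plain `torch.uint8` dtype:

```python
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

x = torch.randn(4, 8)

# No explicit quant_min/quant_max: for symmetric uint8 the observer pins
# zero_point at 128, which is the value HTP can optimize for.
obs_default = MovingAverageMinMaxObserver(
    dtype=torch.uint8, qscheme=torch.per_tensor_symmetric, eps=2**-12
)
obs_default(x)
_, zp = obs_default.calculate_qparams()
print(zp.item())  # 128

# Explicit uint8 range counts as a customized range: the observer takes the
# down-rounded midpoint (0 + 255) // 2 = 127 instead, which the comment
# flags as undesired.
obs_custom = MovingAverageMinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=torch.iinfo(torch.uint8).min,
    quant_max=torch.iinfo(torch.uint8).max,
    eps=2**-12,
)
obs_custom(x)
_, zp = obs_custom.calculate_qparams()
print(zp.item())  # 127
```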
@@ -364,6 +379,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
+        ch_axis=0,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )

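The comment in the first hunk notes that PyTorch folds redundant observers by comparing spec attributes such as dtype, quant_min, quant_max, and ch_axis. Since `QuantizationSpec` is a frozen dataclass, that comparison is field-wise, so a spec that leaves `quant_min`/`quant_max` as `None` will never match one that spells them out, even when both describe the same range. A small sketch of that comparison (illustration only, not from the diff):

```python
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

obs_ctr = MovingAverageMinMaxObserver.with_args(eps=2**-12)

# Spec that leaves quant_min/quant_max implicit (they stay None).
implicit = QuantizationSpec(
    dtype=torch.uint8,
    observer_or_fake_quant_ctr=obs_ctr,
    qscheme=torch.per_tensor_affine,
    ch_axis=0,
)

# Spec that spells out the same uint8 range explicitly.
explicit = QuantizationSpec(
    dtype=torch.uint8,
    observer_or_fake_quant_ctr=obs_ctr,
    quant_min=torch.iinfo(torch.uint8).min,
    quant_max=torch.iinfo(torch.uint8).max,
    qscheme=torch.per_tensor_affine,
    ch_axis=0,
)

# Field-wise dataclass equality: None != 0/255, so the two specs do not
# compare equal even though they describe the same quantization range.
print(implicit == explicit)  # False
```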