@@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config(
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        qscheme=(
-            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
-        ),
-        ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to 128 as the zero_point.
+        # If we provide the uint8 quant_min/quant_max, it uses 127 as the zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch removes redundant observers based on attributes such as
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing values like quant_min and quant_max helps observers compare
+        # and further reduces the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=torch.iinfo(torch.uint8).min,
+            quant_max=torch.iinfo(torch.uint8).max,
+            qscheme=torch.per_tensor_affine,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
+        ch_axis=0,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )
 
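The zero_point behavior described in the new comments can be checked directly against the stock `torch.ao.quantization` observers. The sketch below is illustrative only and not part of this change; it assumes a PyTorch build whose observers accept `dtype=torch.uint8` (as the diff itself does), and the variable names are made up for the example.

```python
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# With the default quant range, a symmetric uint8 observer reports zero_point == 128.
obs_default = MovingAverageMinMaxObserver(
    dtype=torch.uint8, qscheme=torch.per_tensor_symmetric, eps=2**-12
)
obs_default(torch.randn(4, 8))
print(obs_default.calculate_qparams())  # zero_point == 128

# Explicitly passing the uint8 range shifts zero_point to (0 + 255) // 2 == 127,
# which is the case the symmetric branch above avoids.
obs_explicit = MovingAverageMinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
    eps=2**-12,
)
obs_explicit(torch.randn(4, 8))
print(obs_explicit.calculate_qparams())  # zero_point == 127
```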