@@ -237,9 +237,8 @@ def _match_pattern(
237237 torch .ops .aten .clamp .Tensor ,
238238]
239239
240- # Operators that can inherit the quantization specs from its parent node
241- # as SharedQuantizationSpec.
242- _parent_shared_qspec = [
240+ _one_to_one_shared_input_or_input_act_qspec = [
241+ torch .ops .aten .clone .default ,
243242 torch .ops .aten .hardtanh .default ,
244243 torch .ops .aten .hardtanh_ .default ,
245244 torch .ops .aten .relu .default ,
@@ -254,11 +253,6 @@ def _match_pattern(
254253 torch .ops .aten .flatten .using_ints ,
255254 torch .ops .aten .dropout .default ,
256255 torch .ops .aten .dropout_ .default ,
257- torch .ops .aten .where ,
258- operator .getitem ,
259- ]
260-
261- _one_to_one_shared_input_or_input_act_qspec = [
262256 torch .ops .aten .adaptive_avg_pool2d .default ,
263257 torch .ops .aten .alias_copy .default ,
264258]
@@ -404,6 +398,9 @@ def any_or_hardtanh_min_zero(n: Node):
404398 ]
405399 quant_properties .quant_output = _QuantProperty (0 , shared_qspec ) # type: ignore[arg-type]
406400 elif node .target in _one_to_one_shared_input_or_input_act_qspec :
401+ if not isinstance (node .args [0 ], Node ):
402+ return None
403+
407404 input_qspec = (
408405 SharedQuantizationSpec (node .args [0 ]) # type: ignore[arg-type]
409406 if is_output_annotated (node .args [0 ]) # type: ignore
@@ -458,19 +455,16 @@ def any_or_hardtanh_min_zero(n: Node):
458455 ),
459456 ]
460457 quant_properties .quant_output = None
461- elif node .target in _parent_shared_qspec :
462- if not isinstance ( node . args [ 0 ], Node ):
463- return None
464-
465- if not is_output_annotated (node .args [0 ]): # type: ignore[attr-defined]
458+ elif node .target in [ torch . ops . aten . scalar_tensor . default ] :
459+ quant_properties . quant_inputs = []
460+ quant_properties . quant_output = _QuantProperty ( 0 , output_act_qspec )
461+ elif node . target in [ operator . getitem ]:
462+ if not is_output_annotated (node .args [0 ]): # type: ignore[attr-defined, arg-type ]
466463 return None
467464
468- shared_qspec = SharedQuantizationSpec (node .args [0 ])
465+ shared_qspec = SharedQuantizationSpec (node .args [0 ]) # type: ignore[arg-type]
469466 quant_properties .quant_inputs = [_QuantProperty (0 , shared_qspec )] # type: ignore[arg-type]
470467 quant_properties .quant_output = _QuantProperty (0 , shared_qspec ) # type: ignore[arg-type]
471- elif node .target in [torch .ops .aten .scalar_tensor .default ]:
472- quant_properties .quant_inputs = []
473- quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
474468 else :
475469 return None
476470
0 commit comments