@@ -237,9 +237,8 @@ def _match_pattern(
237
237
torch .ops .aten .clamp .Tensor ,
238
238
]
239
239
240
- # Operators that can inherit the quantization specs from its parent node
241
- # as SharedQuantizationSpec.
242
- _parent_shared_qspec = [
240
+ _one_to_one_shared_input_or_input_act_qspec = [
241
+ torch .ops .aten .clone .default ,
243
242
torch .ops .aten .hardtanh .default ,
244
243
torch .ops .aten .hardtanh_ .default ,
245
244
torch .ops .aten .relu .default ,
@@ -254,11 +253,6 @@ def _match_pattern(
254
253
torch .ops .aten .flatten .using_ints ,
255
254
torch .ops .aten .dropout .default ,
256
255
torch .ops .aten .dropout_ .default ,
257
- torch .ops .aten .where ,
258
- operator .getitem ,
259
- ]
260
-
261
- _one_to_one_shared_input_or_input_act_qspec = [
262
256
torch .ops .aten .adaptive_avg_pool2d .default ,
263
257
torch .ops .aten .alias_copy .default ,
264
258
]
@@ -404,6 +398,9 @@ def any_or_hardtanh_min_zero(n: Node):
404
398
]
405
399
quant_properties .quant_output = _QuantProperty (0 , shared_qspec ) # type: ignore[arg-type]
406
400
elif node .target in _one_to_one_shared_input_or_input_act_qspec :
401
+ if not isinstance (node .args [0 ], Node ):
402
+ return None
403
+
407
404
input_qspec = (
408
405
SharedQuantizationSpec (node .args [0 ]) # type: ignore[arg-type]
409
406
if is_output_annotated (node .args [0 ]) # type: ignore
@@ -458,19 +455,16 @@ def any_or_hardtanh_min_zero(n: Node):
458
455
),
459
456
]
460
457
quant_properties .quant_output = None
461
- elif node .target in _parent_shared_qspec :
462
- if not isinstance ( node . args [ 0 ], Node ):
463
- return None
464
-
465
- if not is_output_annotated (node .args [0 ]): # type: ignore[attr-defined]
458
+ elif node .target in [ torch . ops . aten . scalar_tensor . default ] :
459
+ quant_properties . quant_inputs = []
460
+ quant_properties . quant_output = _QuantProperty ( 0 , output_act_qspec )
461
+ elif node . target in [ operator . getitem ]:
462
+ if not is_output_annotated (node .args [0 ]): # type: ignore[attr-defined, arg-type ]
466
463
return None
467
464
468
- shared_qspec = SharedQuantizationSpec (node .args [0 ])
465
+ shared_qspec = SharedQuantizationSpec (node .args [0 ]) # type: ignore[arg-type]
469
466
quant_properties .quant_inputs = [_QuantProperty (0 , shared_qspec )] # type: ignore[arg-type]
470
467
quant_properties .quant_output = _QuantProperty (0 , shared_qspec ) # type: ignore[arg-type]
471
- elif node .target in [torch .ops .aten .scalar_tensor .default ]:
472
- quant_properties .quant_inputs = []
473
- quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
474
468
else :
475
469
return None
476
470
0 commit comments