1212import torch .fx
1313import torch .nn .functional as F
1414from executorch .backends .arm .common .debug import get_node_debug_info
15- from executorch .backends .arm .common .type import ensure_type
1615from executorch .backends .arm .quantizer import QuantizationConfig
1716from torch ._subclasses import FakeTensor
1817
@@ -511,8 +510,7 @@ def any_or_hardtanh_min_zero(n: Node):
511510 torch .ops .aten .minimum .default ,
512511 torch .ops .aten .maximum .default ,
513512 ):
514- lhs_node = ensure_type (Node , node .args [0 ])
515- shared_qspec = SharedQuantizationSpec ((lhs_node , node ))
513+ shared_qspec = SharedQuantizationSpec ((node .args [0 ], node )) # type: ignore[arg-type]
516514 quant_properties .quant_inputs = [
517515 _QuantProperty (0 , input_act_qspec ),
518516 _QuantProperty (
@@ -522,24 +520,22 @@ def any_or_hardtanh_min_zero(n: Node):
522520 ]
523521 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
524522 elif node .target in (torch .ops .aten .where .self ,):
525- true_node = ensure_type (Node , node .args [1 ])
526- shared_qspec = SharedQuantizationSpec (true_node )
523+ shared_qspec = SharedQuantizationSpec (node .args [1 ]) # type: ignore[arg-type]
527524 quant_properties .quant_inputs = [
528525 _QuantProperty (1 , shared_qspec ),
529526 _QuantProperty (2 , shared_qspec ),
530527 ]
531528 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
532529 elif node .target in _one_to_one_shared_input_or_input_act_qspec :
533- input_node = ensure_type (Node , node .args [0 ])
534530 input_qspec = (
535- SharedQuantizationSpec (input_node )
536- if is_output_annotated (input_node )
531+ SharedQuantizationSpec (node . args [ 0 ]) # type: ignore[arg-type]
532+ if is_output_annotated (node . args [ 0 ]) # type: ignore[arg-type]
537533 else input_act_qspec
538534 )
539535 quant_properties .quant_inputs = [_QuantProperty (0 , input_qspec )]
540536 quant_properties .quant_output = _QuantProperty (
541537 0 ,
542- SharedQuantizationSpec ((input_node , node )),
538+ SharedQuantizationSpec ((node . args [ 0 ] , node )), # type: ignore[arg-type]
543539 )
544540 elif node .target in (
545541 torch .ops .aten .cat .default ,
@@ -554,24 +550,26 @@ def any_or_hardtanh_min_zero(n: Node):
554550 )
555551 if len (node .args [0 ]) == 0 :
556552 raise ValueError ("Expected non-empty list for node.args[0]" )
557- inputs = [ ensure_type ( Node , element ) for element in node . args [ 0 ]]
558- shared_qspec = SharedQuantizationSpec ((inputs [0 ], node ))
553+
554+ shared_qspec = SharedQuantizationSpec ((node . args [0 ][ 0 ] , node )) # type: ignore[arg-type]
559555 quant_properties .quant_inputs = [
560556 _QuantProperty (
561557 0 ,
562- [input_act_qspec if n == inputs [0 ] else shared_qspec for n in inputs ],
558+ [
559+ input_act_qspec if n == node .args [0 ][0 ] else shared_qspec # type: ignore[misc]
560+ for n in node .args [0 ]
561+ ],
563562 )
564563 ]
565564 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
566565 elif node .target in _one_to_one :
567566 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
568567 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
569568 elif node .target in _one_to_one_shared_input_qspec :
570- input_node = ensure_type (Node , node .args [0 ])
571569 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
572570 quant_properties .quant_output = _QuantProperty (
573571 0 ,
574- SharedQuantizationSpec ((input_node , node )),
572+ SharedQuantizationSpec ((node . args [ 0 ] , node )), # type: ignore[arg-type]
575573 )
576574 elif node .target in [
577575 torch .ops .aten .eq .Tensor ,
@@ -580,8 +578,7 @@ def any_or_hardtanh_min_zero(n: Node):
580578 torch .ops .aten .le .Tensor ,
581579 torch .ops .aten .lt .Tensor ,
582580 ]:
583- input_node = ensure_type (Node , node .args [0 ])
584- shared_qspec = SharedQuantizationSpec ((input_node , node ))
581+ shared_qspec = SharedQuantizationSpec ((node .args [0 ], node )) # type: ignore[arg-type]
585582 quant_properties .quant_inputs = [
586583 _QuantProperty (0 , input_act_qspec ),
587584 _QuantProperty (
@@ -599,10 +596,9 @@ def any_or_hardtanh_min_zero(n: Node):
599596 quant_properties .quant_inputs = []
600597 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
601598 elif node .target in [operator .getitem ]:
602- input_node = ensure_type (Node , node .args [0 ])
603- if not is_output_annotated (input_node ):
599+ if not is_output_annotated (node .args [0 ]): # type: ignore[arg-type]
604600 return None
605- shared_qspec = SharedQuantizationSpec (input_node )
601+ shared_qspec = SharedQuantizationSpec (node . args [ 0 ]) # type: ignore[arg-type]
606602 quant_properties .quant_inputs = [_QuantProperty (0 , shared_qspec )]
607603 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
608604 else :
0 commit comments