1616 register_node_visitor ,
1717)
1818from executorch .backends .arm .tosa_mapping import TosaArg
19+ from executorch .backends .arm .tosa_specification import TosaSpecification
1920from serializer .tosa_serializer import TosaOp
2021from torch .fx import Node
2122
2223
2324@register_node_visitor
24- class AddVisitor (NodeVisitor ):
25+ class SumVisitor_080_BI (NodeVisitor ):
2526 target = "aten.sum.dim_IntList"
2627
28+ tosa_specs = [
29+ TosaSpecification .create_from_string ("TOSA-0.80.0+BI" ),
30+ ]
31+
2732 def __init__ (self , * args ):
2833 super ().__init__ (* args )
2934
@@ -35,64 +40,89 @@ def define_node(
3540 output : TosaArg ,
3641 is_quant_node : bool ,
3742 ) -> None :
38- input_node = inputs [0 ]
39- input_shape = list (input_node .shape )
43+ input_shape = list (inputs [0 ].shape )
4044 dim_list = cast (list [int ], inputs [1 ].special )
41- dim_list = [dim % len (input_node . shape ) for dim in dim_list ]
45+ dim_list = [dim % len (input_shape ) for dim in dim_list ]
4246 keep_dim = cast (bool , inputs [2 ].number if len (inputs ) > 2 else False )
4347 assert keep_dim , "This case should be handled by InsertSqueezeAfterSumPass"
4448
45- if is_quant_node :
49+ # Rescale input to 32 bit
50+ rescaled_inputs , scale = tqutils .insert_rescale_ops_to_int32 (
51+ tosa_graph ,
52+ [inputs [0 ]],
53+ node ,
54+ )
55+
56+ prev_node = rescaled_inputs [0 ]
57+ reduced_shape = input_shape
58+
59+ # Reduce all dims in dim_list one-by-one.
60+ for dim in dim_list :
61+ # When reduced, the size of the dim becomes 1.
62+ reduced_shape [dim ] = 1
63+
64+ attr = ts .TosaSerializerAttribute ()
65+ attr .AxisAttribute (inputs [0 ].dim_order .index (dim ))
66+
67+ next_node = tosa_graph .addIntermediate (
68+ tutils .tosa_shape (reduced_shape , inputs [0 ].dim_order ),
69+ dtype = ts .DType .INT32 ,
70+ )
71+
72+ tosa_graph .addOperator (
73+ TosaOp .Op ().REDUCE_SUM , [prev_node .name ], [next_node .name ], attr
74+ )
75+
76+ prev_node = next_node
77+ tqutils .insert_rescale_op_to_int8 (tosa_graph , prev_node , scale , node )
78+
79+
80+ @register_node_visitor
81+ class SumVisitor_080_MI (SumVisitor_080_BI ):
82+ # inheriting 'target' from BI class
83+
84+ tosa_specs = [
85+ TosaSpecification .create_from_string ("TOSA-0.80.0+MI" ),
86+ ]
87+
88+ def __init__ (self , * args ):
89+ super ().__init__ (* args )
90+
91+ def define_node (
92+ self ,
93+ node : Node ,
94+ tosa_graph : ts .TosaSerializer ,
95+ inputs : List [TosaArg ],
96+ output : TosaArg ,
97+ is_quant_node : bool ,
98+ ) -> None :
99+ if inputs [0 ].dtype == ts .DType .INT8 :
100+ return super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
101+ input_name = inputs [0 ].name
102+ reduced_shape = list (inputs [0 ].shape )
103+ dim_list = cast (list [int ], inputs [1 ].special )
104+ dim_list = [dim % len (reduced_shape ) for dim in dim_list ]
105+ keep_dim = cast (bool , inputs [2 ].number if len (inputs ) > 2 else False )
106+ assert keep_dim , "This case should be handled by InsertSqueezeAfterSumPass"
107+
108+ # Reduce all dims in dim_list one-by-one.
109+ for dim in dim_list :
110+ # When reduced, the size of the dim becomes 1
111+ reduced_shape [dim ] = 1
112+
113+ attr = ts .TosaSerializerAttribute ()
114+ attr .AxisAttribute (inputs [0 ].dim_order .index (dim ))
115+
116+ if dim == dim_list [- 1 ]:
117+ output_name = output .name
118+ else :
119+ output_name = tosa_graph .addIntermediate (
120+ tutils .tosa_shape (reduced_shape , inputs [0 ].dim_order ),
121+ dtype = ts .DType .FP32 ,
122+ ).name
46123
47- # Rescale input to 32 bit
48- rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
49- [node .all_input_nodes [0 ]], tosa_graph
124+ tosa_graph .addOperator (
125+ TosaOp .Op ().REDUCE_SUM , [input_name ], [output_name ], attr
50126 )
51127
52- prev_node = rescaled_inputs [0 ]
53- reduced_shape = input_shape
54-
55- # Reduce all dims in dim_list one-by-one.
56- for dim in dim_list :
57- # When reduced, the size of the dim becomes 1.
58- reduced_shape [dim ] = 1
59-
60- attr = ts .TosaSerializerAttribute ()
61- attr .AxisAttribute (input_node .dim_order .index (dim ))
62-
63- next_node = tosa_graph .addIntermediate (
64- tutils .tosa_shape (reduced_shape , input_node .dim_order ),
65- dtype = ts .DType .INT32 ,
66- )
67-
68- tosa_graph .addOperator (
69- TosaOp .Op ().REDUCE_SUM , [prev_node .name ], [next_node .name ], attr
70- )
71-
72- prev_node = next_node
73- tqutils .rescale_node_back_to_int8 (node , prev_node , scale , tosa_graph )
74- else :
75- input_name = input_node .name
76- reduced_shape = input_shape
77-
78- # Reduce all dims in dim_list one-by-one.
79- for dim in dim_list :
80- # When reduced, the size of the dim becomes 1
81- reduced_shape [dim ] = 1
82-
83- attr = ts .TosaSerializerAttribute ()
84- attr .AxisAttribute (input_node .dim_order .index (dim ))
85-
86- if dim == dim_list [- 1 ]:
87- output_name = output .name
88- else :
89- output_name = tosa_graph .addIntermediate (
90- tutils .tosa_shape (reduced_shape , input_node .dim_order ),
91- dtype = ts .DType .FP32 ,
92- ).name
93-
94- tosa_graph .addOperator (
95- TosaOp .Op ().REDUCE_SUM , [input_name ], [output_name ], attr
96- )
97-
98- input_name = output_name
128+ input_name = output_name
0 commit comments