55
66# pyre-unsafe
77
8- from typing import List
8+ from typing import Any , List
99
1010import executorch .backends .arm .tosa_quant_utils as tqutils
1111import executorch .backends .arm .tosa_utils as tutils
1212
13- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
1413from executorch .backends .arm .operators .node_visitor import (
1514 NodeVisitor ,
1615 register_node_visitor ,
@@ -34,10 +33,13 @@ def __init__(self, *args):
3433 def define_node (
3534 self ,
3635 node : Node ,
37- tosa_graph : ts . TosaSerializer ,
36+ tosa_graph : Any ,
3837 inputs : List [TosaArg ],
3938 output : TosaArg ,
4039 ) -> None :
40+
41+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
42+
4143 # Specification (0.80) states that input and output types
4244 # should all be the same
4345 if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
@@ -58,7 +60,7 @@ def define_node(
5860 if len (inputs [0 ].shape ) > len (inputs [1 ].shape )
5961 else inputs [1 ].dim_order
6062 )
61-
63+ scale_back = 1.0
6264 if inputs [0 ].dtype == ts .DType .INT8 :
6365 rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
6466 tosa_graph , inputs , node
@@ -90,7 +92,9 @@ def define_node(
9092 if output .dtype == ts .DType .INT8 :
9193 # Scale output back to 8 bit
9294 # pyre-ignore
93- tqutils .insert_rescale_op_to_int8 (tosa_graph , add_output , scale_back , node ) # type: ignore[possibly-undefined]
95+ tqutils .insert_rescale_op_to_int8 (
96+ tosa_graph , add_output , scale_back , node
97+ ) # type: ignore[possibly-undefined]
9498
9599
96100@register_node_visitor
@@ -107,10 +111,13 @@ def __init__(self, *args):
107111 def define_node (
108112 self ,
109113 node : Node ,
110- tosa_graph : ts . TosaSerializer ,
114+ tosa_graph : Any ,
111115 inputs : List [TosaArg ],
112116 output : TosaArg ,
113117 ) -> None :
118+
119+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120+
114121 # Specification (0.80) states that input and output types
115122 # should all be the same
116123 if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
@@ -130,7 +137,7 @@ def define_node(
130137 f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
131138 )
132139
133- input1 , input2 = tutils . reshape_for_broadcast ( tosa_graph , inputs )
140+ input1 , input2 = inputs
134141
135142 # MI lowering
136143 tosa_graph .addOperator (
@@ -139,3 +146,122 @@ def define_node(
139146 [output .name ],
140147 None ,
141148 )
149+
150+
151+ @register_node_visitor
152+ class AddVisitor_INT (NodeVisitor ):
153+ target = "aten.add.Tensor"
154+
155+ tosa_specs = [
156+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
157+ ]
158+
159+ def __init__ (self , * args ):
160+ super ().__init__ (* args )
161+
162+ def define_node (
163+ self ,
164+ node : Node ,
165+ tosa_graph : Any ,
166+ inputs : List [TosaArg ],
167+ output : TosaArg ,
168+ ) -> None :
169+
170+ import serializer .tosa_serializer as ts # type: ignore
171+
172+ # Specification (1.0) states that input and output types
173+ # should all be the same
174+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
175+ raise TypeError (
176+ f"All IO needs to have the same data type, got input 1: "
177+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
178+ f"{ output .dtype } "
179+ )
180+ # Handle int8 (quantized) and int32
181+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
182+ if inputs [0 ].dtype not in supported_dtypes :
183+ raise TypeError (
184+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
185+ )
186+ scale_back = 1.0
187+ if inputs [0 ].dtype == ts .DType .INT8 :
188+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
189+ tosa_graph , inputs , node , self .tosa_specs
190+ )
191+ else :
192+ # input[0].dtype == ts.DType.INT32
193+ # Non quantized input, natively support by TOSA.ADD
194+ rescaled_inputs = inputs
195+
196+ if output .dtype == ts .DType .INT8 :
197+ broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
198+ add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
199+ else :
200+ # output.dtype == ts.DType.INT32
201+ add_output = output
202+
203+ input1 , input2 = rescaled_inputs
204+
205+ # Do the INT32 Add
206+ tosa_graph .addOperator (
207+ ts .TosaOp .Op ().ADD ,
208+ [input1 .name , input2 .name ],
209+ [add_output .name ],
210+ None ,
211+ )
212+
213+ if output .dtype == ts .DType .INT8 :
214+ # Scale output back to 8 bit
215+ # pyre-ignore
216+ tqutils .insert_rescale_op_to_int8 (
217+ tosa_graph , add_output , scale_back , node , self .tosa_specs
218+ ) # type: ignore[possibly-undefined]
219+
220+
221+ @register_node_visitor
222+ class AddVisitor_FP (AddVisitor_INT ):
223+ # inheriting 'target' from INT class
224+
225+ tosa_specs = [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
226+
227+ def __init__ (self , * args ):
228+ super ().__init__ (* args )
229+
230+ def define_node (
231+ self ,
232+ node : Node ,
233+ tosa_graph : Any ,
234+ inputs : List [TosaArg ],
235+ output : TosaArg ,
236+ ) -> None :
237+
238+ import serializer .tosa_serializer as ts # type: ignore
239+
240+ # Specification (1.0) states that input and output types
241+ # should all be the same
242+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
243+ raise TypeError (
244+ f"All IO needs to have the same data type, got input 1: "
245+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
246+ f"{ output .dtype } "
247+ )
248+
249+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
250+ # Call the inherited define_node for handling integers
251+ super ().define_node (node , tosa_graph , inputs , output )
252+ else :
253+ # FP32 Add lowering
254+ if inputs [0 ].dtype != ts .DType .FP32 :
255+ raise TypeError (
256+ f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
257+ )
258+
259+ input1 , input2 = inputs
260+
261+ # FP lowering
262+ tosa_graph .addOperator (
263+ ts .TosaOp .Op ().ADD ,
264+ [input1 .name , input2 .name ],
265+ [output .name ],
266+ None ,
267+ )
0 commit comments