55
66# pyre-unsafe 
77
8- from  typing  import  List 
8+ from  typing  import  Any ,  List 
99
1010import  executorch .backends .arm .tosa_quant_utils  as  tqutils 
1111import  executorch .backends .arm .tosa_utils  as  tutils 
1212
13- import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
1413from  executorch .backends .arm .operators .node_visitor  import  (
1514    NodeVisitor ,
1615    register_node_visitor ,
@@ -34,10 +33,13 @@ def __init__(self, *args):
3433    def  define_node (
3534        self ,
3635        node : Node ,
37-         tosa_graph : ts . TosaSerializer ,
36+         tosa_graph : Any ,
3837        inputs : List [TosaArg ],
3938        output : TosaArg ,
4039    ) ->  None :
40+ 
41+         import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
42+ 
4143        # Specification (0.80) states that input and output types 
4244        # should all be the same 
4345        if  inputs [0 ].dtype  !=  inputs [1 ].dtype  or  inputs [0 ].dtype  !=  output .dtype :
@@ -58,7 +60,7 @@ def define_node(
5860            if  len (inputs [0 ].shape ) >  len (inputs [1 ].shape )
5961            else  inputs [1 ].dim_order 
6062        )
61- 
63+          scale_back   =   1.0 
6264        if  inputs [0 ].dtype  ==  ts .DType .INT8 :
6365            rescaled_inputs , scale_back  =  tqutils .insert_rescale_ops_to_int32 (
6466                tosa_graph , inputs , node 
@@ -90,7 +92,9 @@ def define_node(
9092        if  output .dtype  ==  ts .DType .INT8 :
9193            # Scale output back to 8 bit 
9294            # pyre-ignore 
93-             tqutils .insert_rescale_op_to_int8 (tosa_graph , add_output , scale_back , node )  # type: ignore[possibly-undefined] 
95+             tqutils .insert_rescale_op_to_int8 (
96+                 tosa_graph , add_output , scale_back , node 
97+             )  # type: ignore[possibly-undefined] 
9498
9599
96100@register_node_visitor  
@@ -107,10 +111,13 @@ def __init__(self, *args):
107111    def  define_node (
108112        self ,
109113        node : Node ,
110-         tosa_graph : ts . TosaSerializer ,
114+         tosa_graph : Any ,
111115        inputs : List [TosaArg ],
112116        output : TosaArg ,
113117    ) ->  None :
118+ 
119+         import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
120+ 
114121        # Specification (0.80) states that input and output types 
115122        # should all be the same 
116123        if  inputs [0 ].dtype  !=  inputs [1 ].dtype  or  inputs [0 ].dtype  !=  output .dtype :
@@ -130,7 +137,7 @@ def define_node(
130137                    f"Expected IO data type to be FP32, got { inputs [0 ].dtype }  
131138                )
132139
133-             input1 , input2  =  tutils . reshape_for_broadcast ( tosa_graph ,  inputs ) 
140+             input1 , input2  =  inputs 
134141
135142            # MI lowering 
136143            tosa_graph .addOperator (
@@ -139,3 +146,122 @@ def define_node(
139146                [output .name ],
140147                None ,
141148            )
149+ 
150+ 
151+ @register_node_visitor  
152+ class  AddVisitor_INT (NodeVisitor ):
153+     target  =  "aten.add.Tensor" 
154+ 
155+     tosa_specs  =  [
156+         TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
157+     ]
158+ 
159+     def  __init__ (self , * args ):
160+         super ().__init__ (* args )
161+ 
162+     def  define_node (
163+         self ,
164+         node : Node ,
165+         tosa_graph : Any ,
166+         inputs : List [TosaArg ],
167+         output : TosaArg ,
168+     ) ->  None :
169+ 
170+         import  serializer .tosa_serializer  as  ts   # type: ignore 
171+ 
172+         # Specification (1.0) states that input and output types 
173+         # should all be the same 
174+         if  inputs [0 ].dtype  !=  inputs [1 ].dtype  or  inputs [0 ].dtype  !=  output .dtype :
175+             raise  TypeError (
176+                 f"All IO needs to have the same data type, got input 1: " 
177+                 f"{ inputs [0 ].dtype } { inputs [1 ].dtype }  
178+                 f"{ output .dtype }  
179+             )
180+         # Handle int8 (quantized) and int32 
181+         supported_dtypes  =  [ts .DType .INT8 , ts .DType .INT32 ]
182+         if  inputs [0 ].dtype  not  in supported_dtypes :
183+             raise  TypeError (
184+                 f'IO data type needs to be { supported_dtypes } { inputs [0 ].dtype }  
185+             )
186+         scale_back  =  1.0 
187+         if  inputs [0 ].dtype  ==  ts .DType .INT8 :
188+             rescaled_inputs , scale_back  =  tqutils .insert_rescale_ops_to_int32 (
189+                 tosa_graph , inputs , node , self .tosa_specs 
190+             )
191+         else :
192+             # input[0].dtype == ts.DType.INT32 
193+             # Non quantized input, natively support by TOSA.ADD 
194+             rescaled_inputs  =  inputs 
195+ 
196+         if  output .dtype  ==  ts .DType .INT8 :
197+             broadcasted_shape  =  tutils .tosa_shape (output .shape , output .dim_order )
198+             add_output  =  tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
199+         else :
200+             # output.dtype == ts.DType.INT32 
201+             add_output  =  output 
202+ 
203+         input1 , input2  =  rescaled_inputs 
204+ 
205+         # Do the INT32 Add 
206+         tosa_graph .addOperator (
207+             ts .TosaOp .Op ().ADD ,
208+             [input1 .name , input2 .name ],
209+             [add_output .name ],
210+             None ,
211+         )
212+ 
213+         if  output .dtype  ==  ts .DType .INT8 :
214+             # Scale output back to 8 bit 
215+             # pyre-ignore 
216+             tqutils .insert_rescale_op_to_int8 (
217+                 tosa_graph , add_output , scale_back , node , self .tosa_specs 
218+             )  # type: ignore[possibly-undefined] 
219+ 
220+ 
221+ @register_node_visitor  
222+ class  AddVisitor_FP (AddVisitor_INT ):
223+     # inheriting 'target' from INT class 
224+ 
225+     tosa_specs  =  [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
226+ 
227+     def  __init__ (self , * args ):
228+         super ().__init__ (* args )
229+ 
230+     def  define_node (
231+         self ,
232+         node : Node ,
233+         tosa_graph : Any ,
234+         inputs : List [TosaArg ],
235+         output : TosaArg ,
236+     ) ->  None :
237+ 
238+         import  serializer .tosa_serializer  as  ts   # type: ignore 
239+ 
240+         # Specification (1.0) states that input and output types 
241+         # should all be the same 
242+         if  inputs [0 ].dtype  !=  inputs [1 ].dtype  or  inputs [0 ].dtype  !=  output .dtype :
243+             raise  TypeError (
244+                 f"All IO needs to have the same data type, got input 1: " 
245+                 f"{ inputs [0 ].dtype } { inputs [1 ].dtype }  
246+                 f"{ output .dtype }  
247+             )
248+ 
249+         if  inputs [0 ].dtype  in  [ts .DType .INT8 , ts .DType .INT32 ]:
250+             # Call the inherited define_node for handling integers 
251+             super ().define_node (node , tosa_graph , inputs , output )
252+         else :
253+             # FP32 Add lowering 
254+             if  inputs [0 ].dtype  !=  ts .DType .FP32 :
255+                 raise  TypeError (
256+                     f"Expected IO data type to be FP32, got { inputs [0 ].dtype }  
257+                 )
258+ 
259+             input1 , input2  =  inputs 
260+ 
261+             # FP lowering 
262+             tosa_graph .addOperator (
263+                 ts .TosaOp .Op ().ADD ,
264+                 [input1 .name , input2 .name ],
265+                 [output .name ],
266+                 None ,
267+             )
0 commit comments