44# LICENSE file in the root directory of this source tree. 
55
66# pyre-unsafe 
7- from  typing  import  List 
7+ from  typing  import  Any ,  List 
88
99import  executorch .backends .arm .tosa_quant_utils  as  tqutils 
1010import  executorch .backends .arm .tosa_utils  as  tutils 
1111
12- import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
1312from  executorch .backends .arm .operators .node_visitor  import  (
1413    NodeVisitor ,
1514    register_node_visitor ,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332    def  define_node (
3433        self ,
3534        node : Node ,
36-         tosa_graph : ts . TosaSerializer ,
35+         tosa_graph : Any ,
3736        inputs : List [TosaArg ],
3837        output : TosaArg ,
3938    ) ->  None :
39+ 
40+         import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
41+ 
4042        # Specification (0.80) states that input and output types 
4143        # should all be the same 
4244        if  not  (inputs [0 ].dtype  ==  output .dtype ):
@@ -53,7 +55,7 @@ def define_node(
5355        if  inputs [0 ].dtype  ==  ts .DType .INT8 :
5456            rescaled_inputs , scale_back  =  tqutils .insert_rescale_ops_to_int32 (
5557                tosa_graph , inputs , node 
56-             )
58+             )   # type: ignore[possibly-undefined] 
5759        else :
5860            # input[0].dtype == ts.DType.INT32 
5961            # Non quantized input, natively support by TOSA.abs 
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698    def  define_node (
9799        self ,
98100        node : Node ,
99-         tosa_graph : ts . TosaSerializer ,
101+         tosa_graph : Any ,
100102        inputs : List [TosaArg ],
101103        output : TosaArg ,
102104    ) ->  None :
105+ 
106+         import  tosa_tools .v0_80 .serializer .tosa_serializer  as  ts   # type: ignore 
107+ 
103108        # Specification (0.80) states that input and output types 
104109        # should all be the same 
105110        if  not  (inputs [0 ].dtype  ==  output .dtype ):
@@ -129,3 +134,122 @@ def define_node(
129134                [output .name ],
130135                None ,
131136            )
137+ 
138+ 
139+ @register_node_visitor  
140+ class  AbsVisitor_INT (NodeVisitor ):
141+     target  =  "aten.abs.default" 
142+ 
143+     tosa_specs  =  [
144+         TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145+     ]
146+ 
147+     def  __init__ (self , * args ):
148+         super ().__init__ (* args )
149+ 
150+     def  define_node (
151+         self ,
152+         node : Node ,
153+         tosa_graph : Any ,
154+         inputs : List [TosaArg ],
155+         output : TosaArg ,
156+     ) ->  None :
157+ 
158+         import  serializer .tosa_serializer  as  ts   # type: ignore 
159+ 
160+         # Specification (1.0) states that input and output types 
161+         # should all be the same 
162+         if  not  (inputs [0 ].dtype  ==  output .dtype ):
163+             raise  ValueError (
164+                 "All inputs and outputs need same dtype." 
165+                 f"Got { inputs [0 ].dtype = } { output .dtype = }  
166+             )
167+         # Handle int8 (quantized) and int32 
168+         if  not  (inputs [0 ].dtype  in  [ts .DType .INT8 , ts .DType .INT32 ]):
169+             raise  ValueError (
170+                 "All inputs need to be INT8 or INT32."  f"Got { inputs [0 ].dtype = }  
171+             )
172+ 
173+         scale_back  =  1.0 
174+         if  inputs [0 ].dtype  ==  ts .DType .INT8 :
175+             rescaled_inputs , scale_back  =  tqutils .insert_rescale_ops_to_int32 (
176+                 tosa_graph , inputs , node , self .tosa_specs 
177+             )  # type: ignore[possibly-undefined] 
178+         else :
179+             # input[0].dtype == ts.DType.INT32 
180+             # Non quantized input, natively support by TOSA.abs 
181+             rescaled_inputs  =  inputs 
182+ 
183+         if  output .dtype  ==  ts .DType .INT8 :
184+             broadcasted_shape  =  tutils .tosa_shape (output .shape , output .dim_order )
185+             abs_output  =  tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
186+         else :
187+             # output.dtype == ts.DType.INT32 
188+             abs_output  =  output 
189+ 
190+         # Do the INT32 Abs 
191+         tosa_graph .addOperator (
192+             ts .TosaOp .Op ().ABS ,
193+             [
194+                 rescaled_inputs [0 ].name ,
195+             ],
196+             [abs_output .name ],
197+             None ,
198+         )
199+ 
200+         if  output .dtype  ==  ts .DType .INT8 :
201+             # Scale output back to 8 bit 
202+             # pyre-ignore 
203+             tqutils .insert_rescale_op_to_int8 (
204+                 tosa_graph , abs_output , scale_back , node , self .tosa_specs 
205+             )  # type: ignore[possibly-undefined] 
206+ 
207+ 
208+ @register_node_visitor  
209+ class  AbsVisitor_FP (AbsVisitor_INT ):
210+     # inheriting 'target' from BI class 
211+ 
212+     tosa_specs  =  [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
213+ 
214+     def  __init__ (self , * args ):
215+         super ().__init__ (* args )
216+ 
217+     def  define_node (
218+         self ,
219+         node : Node ,
220+         tosa_graph : Any ,
221+         inputs : List [TosaArg ],
222+         output : TosaArg ,
223+     ) ->  None :
224+ 
225+         import  serializer .tosa_serializer  as  ts   # type: ignore 
226+ 
227+         # Specification (1.0) states that input and output types 
228+         # should all be the same 
229+         if  not  (inputs [0 ].dtype  ==  output .dtype ):
230+             raise  ValueError (
231+                 "All inputs and output need same dtype." 
232+                 f"Got { inputs [0 ].dtype = } { output .dtype = }  
233+             )
234+ 
235+         if  inputs [0 ].dtype  in  [ts .DType .INT8 , ts .DType .INT32 ]:
236+             # Call the inherited define_node for handling integers 
237+             super ().define_node (node , tosa_graph , inputs , output )
238+         else :
239+             # FP32 Abs lowering 
240+ 
241+             if  not  (inputs [0 ].dtype  ==  ts .DType .FP32 ):
242+                 raise  ValueError (
243+                     "All inputs need to be FP32."  f"Got { inputs [0 ].dtype = }  
244+                 )
245+ 
246+             if  not  (output .dtype  ==  ts .DType .FP32 ):
247+                 raise  ValueError ("All outputs need to be FP32."  f"Got { output .dtype = }  )
248+ 
249+             # MI lowering 
250+             tosa_graph .addOperator (
251+                 ts .TosaOp .Op ().ABS ,
252+                 [inputs [0 ].name ],
253+                 [output .name ],
254+                 None ,
255+             )
0 commit comments