44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7- from typing import List
7+ from typing import Any , List
88
99import executorch .backends .arm .tosa_quant_utils as tqutils
1010import executorch .backends .arm .tosa_utils as tutils
1111
12- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
1312from executorch .backends .arm .operators .node_visitor import (
1413 NodeVisitor ,
1514 register_node_visitor ,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332 def define_node (
3433 self ,
3534 node : Node ,
36- tosa_graph : ts . TosaSerializer ,
35+ tosa_graph : Any ,
3736 inputs : List [TosaArg ],
3837 output : TosaArg ,
3938 ) -> None :
39+
40+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
41+
4042 # Specification (0.80) states that input and output types
4143 # should all be the same
4244 if not (inputs [0 ].dtype == output .dtype ):
@@ -53,7 +55,7 @@ def define_node(
5355 if inputs [0 ].dtype == ts .DType .INT8 :
5456 rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
5557 tosa_graph , inputs , node
56- )
58+ ) # type: ignore[possibly-undefined]
5759 else :
5860 # input[0].dtype == ts.DType.INT32
5961 # Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698 def define_node (
9799 self ,
98100 node : Node ,
99- tosa_graph : ts . TosaSerializer ,
101+ tosa_graph : Any ,
100102 inputs : List [TosaArg ],
101103 output : TosaArg ,
102104 ) -> None :
105+
106+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
107+
103108 # Specification (0.80) states that input and output types
104109 # should all be the same
105110 if not (inputs [0 ].dtype == output .dtype ):
@@ -129,3 +134,122 @@ def define_node(
129134 [output .name ],
130135 None ,
131136 )
137+
138+
139+ @register_node_visitor
140+ class AbsVisitor_INT (NodeVisitor ):
141+ target = "aten.abs.default"
142+
143+ tosa_specs = [
144+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145+ ]
146+
147+ def __init__ (self , * args ):
148+ super ().__init__ (* args )
149+
150+ def define_node (
151+ self ,
152+ node : Node ,
153+ tosa_graph : Any ,
154+ inputs : List [TosaArg ],
155+ output : TosaArg ,
156+ ) -> None :
157+
158+ import serializer .tosa_serializer as ts # type: ignore
159+
160+ # Specification (1.0) states that input and output types
161+ # should all be the same
162+ if not (inputs [0 ].dtype == output .dtype ):
163+ raise ValueError (
164+ "All inputs and outputs need same dtype."
165+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
166+ )
167+ # Handle int8 (quantized) and int32
168+ if not (inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]):
169+ raise ValueError (
170+ "All inputs need to be INT8 or INT32." f"Got { inputs [0 ].dtype = } "
171+ )
172+
173+ scale_back = 1.0
174+ if inputs [0 ].dtype == ts .DType .INT8 :
175+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
176+ tosa_graph , inputs , node , self .tosa_specs
177+ ) # type: ignore[possibly-undefined]
178+ else :
179+ # input[0].dtype == ts.DType.INT32
180+ # Non quantized input, natively support by TOSA.abs
181+ rescaled_inputs = inputs
182+
183+ if output .dtype == ts .DType .INT8 :
184+ broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
185+ abs_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
186+ else :
187+ # output.dtype == ts.DType.INT32
188+ abs_output = output
189+
190+ # Do the INT32 Abs
191+ tosa_graph .addOperator (
192+ ts .TosaOp .Op ().ABS ,
193+ [
194+ rescaled_inputs [0 ].name ,
195+ ],
196+ [abs_output .name ],
197+ None ,
198+ )
199+
200+ if output .dtype == ts .DType .INT8 :
201+ # Scale output back to 8 bit
202+ # pyre-ignore
203+ tqutils .insert_rescale_op_to_int8 (
204+ tosa_graph , abs_output , scale_back , node , self .tosa_specs
205+ ) # type: ignore[possibly-undefined]
206+
207+
208+ @register_node_visitor
209+ class AbsVisitor_FP (AbsVisitor_INT ):
210+ # inheriting 'target' from BI class
211+
212+ tosa_specs = [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
213+
214+ def __init__ (self , * args ):
215+ super ().__init__ (* args )
216+
217+ def define_node (
218+ self ,
219+ node : Node ,
220+ tosa_graph : Any ,
221+ inputs : List [TosaArg ],
222+ output : TosaArg ,
223+ ) -> None :
224+
225+ import serializer .tosa_serializer as ts # type: ignore
226+
227+ # Specification (1.0) states that input and output types
228+ # should all be the same
229+ if not (inputs [0 ].dtype == output .dtype ):
230+ raise ValueError (
231+ "All inputs and output need same dtype."
232+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
233+ )
234+
235+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
236+ # Call the inherited define_node for handling integers
237+ super ().define_node (node , tosa_graph , inputs , output )
238+ else :
239+ # FP32 Abs lowering
240+
241+ if not (inputs [0 ].dtype == ts .DType .FP32 ):
242+ raise ValueError (
243+ "All inputs need to be FP32." f"Got { inputs [0 ].dtype = } "
244+ )
245+
246+ if not (output .dtype == ts .DType .FP32 ):
247+ raise ValueError ("All outputs need to be FP32." f"Got { output .dtype = } " )
248+
249+ # MI lowering
250+ tosa_graph .addOperator (
251+ ts .TosaOp .Op ().ABS ,
252+ [inputs [0 ].name ],
253+ [output .name ],
254+ None ,
255+ )
0 commit comments