1111import executorch .backends .arm .tosa_utils as tutils
1212
1313import serializer .tosa_serializer as ts
14+ import torch
1415from executorch .backends .arm .operators .node_visitor import (
1516 NodeVisitor ,
1617 register_node_visitor ,
1718)
1819from executorch .backends .arm .tosa_mapping import TosaArg
20+ from executorch .backends .arm .tosa_specification import TosaSpecification
1921from serializer .tosa_serializer import TosaOp
2022from torch .fx import Node
2123
2224
2325@register_node_visitor
24- class AddVisitor (NodeVisitor ):
26+ class AddVisitor_080_BI (NodeVisitor ):
2527 target = "aten.add.Tensor"
2628
29+ tosa_specs = [
30+ TosaSpecification .create_from_string ("TOSA-0.80.0+BI" ),
31+ ]
32+
2733 def __init__ (self , * args ):
2834 super ().__init__ (* args )
2935
@@ -35,9 +41,22 @@ def define_node(
3541 output : TosaArg ,
3642 is_quant_node : bool ,
3743 ) -> None :
38- if is_quant_node :
39- input_nodes = tutils .get_two_inputs (node )
44+ input_nodes = tutils .get_two_inputs (node )
45+
46+ if not is_quant_node and not all (
47+ tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48+ for tensor in input_nodes
49+ ):
50+ raise RuntimeError (
51+ f"Unexpected non quantized { AddVisitor_080_BI .target } node."
52+ )
4053
54+ needs_rescale = not (
55+ all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56+ and node .meta ["val" ].dtype == torch .int32
57+ )
58+
59+ if needs_rescale :
4160 # Rescale inputs to 32 bit
4261 rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
4362 input_nodes , tosa_graph
@@ -48,20 +67,48 @@ def define_node(
4867 rescaled_inputs [0 ].shape , rescaled_inputs [0 ].shape
4968 )
5069 add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
70+ else :
71+ add_output = output
72+ rescaled_inputs = inputs
5173
52- # Do the INT32 Add
53- tosa_graph .addOperator (
54- TosaOp .Op ().ADD ,
55- [
56- rescaled_inputs [0 ].name ,
57- rescaled_inputs [1 ].name ,
58- ],
59- [add_output .name ],
60- None ,
61- )
74+ # Do the INT32 Add
75+ tosa_graph .addOperator (
76+ TosaOp .Op ().ADD ,
77+ [
78+ rescaled_inputs [0 ].name ,
79+ rescaled_inputs [1 ].name ,
80+ ],
81+ [add_output .name ],
82+ None ,
83+ )
6284
85+ if needs_rescale :
6386 # Scale output back to 8 bit
6487 tqutils .rescale_node_back_to_int8 (node , add_output , scale , tosa_graph )
88+
89+
90+ @register_node_visitor
91+ class AddVisitor_080_MI (AddVisitor_080_BI ):
92+ # inheriting 'target' from BI class
93+
94+ tosa_specs = [
95+ TosaSpecification .create_from_string ("TOSA-0.80.0+MI" ),
96+ ]
97+
98+ def __init__ (self , * args ):
99+ super ().__init__ (* args )
100+
101+ def define_node (
102+ self ,
103+ node : Node ,
104+ tosa_graph : ts .TosaSerializer ,
105+ inputs : List [TosaArg ],
106+ output : TosaArg ,
107+ is_quant_node : bool ,
108+ ) -> None :
109+ if is_quant_node :
110+ # Call the inherited define_node for handling integers
111+ super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
65112 else :
66113 # FP32 Add lowering
67114 tosa_graph .addOperator (
0 commit comments