1111import executorch .backends .arm .tosa_utils as tutils
1212
1313import serializer .tosa_serializer as ts
14- import torch
1514from executorch .backends .arm .operators .node_visitor import (
1615 NodeVisitor ,
1716 register_node_visitor ,
@@ -41,33 +40,27 @@ def define_node(
4140 output : TosaArg ,
4241 is_quant_node : bool ,
4342 ) -> None :
44- input_nodes = tutils .get_two_inputs (node )
45-
46- if not is_quant_node and not all (
47- tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48- for tensor in input_nodes
49- ):
50- raise RuntimeError (
51- f"Unexpected non quantized { AddVisitor_080_BI .target } node."
52- )
53-
54- needs_rescale = not (
55- all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56- and node .meta ["val" ].dtype == torch .int32
57- )
58-
59- if needs_rescale :
60- # Rescale inputs to 32 bit
61- rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
62- input_nodes , tosa_graph
43+ # Specification (0.80.0) states that input and output types
44+ # should all be the same
45+ assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
46+ # Handle int8 (quantized) and int32
47+ assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
48+
49+ if inputs [0 ].dtype == ts .DType .INT8 :
50+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
51+ tosa_graph , inputs , node
6352 )
53+ else :
54+ # inputs[0].dtype == ts.DType.INT32
55+ # Non-quantized input, natively supported by TOSA.ADD
56+ rescaled_inputs = inputs
6457
65- # Prepare add output tensor
58+ if output . dtype == ts . DType . INT8 :
6659 broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
6760 add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
6861 else :
62+ # output.dtype == ts.DType.INT32
6963 add_output = output
70- rescaled_inputs = inputs
7164
7265 # Do the INT32 Add
7366 tosa_graph .addOperator (
@@ -80,10 +73,12 @@ def define_node(
8073 None ,
8174 )
8275
83- if needs_rescale :
76+ if output . dtype == ts . DType . INT8 :
8477 # Scale output back to 8 bit
8578 # pyre-ignore
86- tqutils .rescale_node_back_to_int8 (node , add_output , scale , tosa_graph )
79+ tqutils .insert_rescale_node_back_to_int8 (
80+ tosa_graph , add_output , scale_back , node
81+ )
8782
8883
8984@register_node_visitor
@@ -105,11 +100,19 @@ def define_node(
105100 output : TosaArg ,
106101 is_quant_node : bool ,
107102 ) -> None :
108- if is_quant_node :
103+ # Specification (0.80.0) states that input and output types
104+ # should all be the same
105+ assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
106+
107+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
109108 # Call the inherited define_node for handling integers
110109 super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
111110 else :
112111 # FP32 Add lowering
112+ assert inputs [0 ].dtype == ts .DType .FP32
113+ assert output .dtype == ts .DType .FP32
114+
115+ # MI lowering
113116 tosa_graph .addOperator (
114117 TosaOp .Op ().ADD ,
115118 [inputs [0 ].name , inputs [1 ].name ],
0 commit comments