1111import executorch .backends .arm .tosa_utils as tutils
1212
1313import serializer .tosa_serializer as ts
14+ import torch
1415from executorch .backends .arm .operators .node_visitor import (
1516 NodeVisitor ,
1617 register_node_visitor ,
@@ -40,27 +41,33 @@ def define_node(
4041 output : TosaArg ,
4142 is_quant_node : bool ,
4243 ) -> None :
43- # Specification (0.80) states that input and output types
44- # should all be the same
45- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
46- # Handle int8 (quantized) and int32
47- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
48-
49- if inputs [0 ].dtype == ts .DType .INT8 :
50- rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
51- tosa_graph , inputs , node
44+ input_nodes = tutils .get_two_inputs (node )
45+
46+ if not is_quant_node and not all (
47+ tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48+ for tensor in input_nodes
49+ ):
50+ raise RuntimeError (
51+ f"Unexpected non quantized { AddVisitor_080_BI .target } node."
5252 )
53- else :
54- # input[0].dtype == ts.DType.INT32
55- # Non quantized input, natively support by TOSA.ADD
56- rescaled_inputs = inputs
5753
58- if output .dtype == ts .DType .INT8 :
54+ needs_rescale = not (
55+ all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56+ and node .meta ["val" ].dtype == torch .int32
57+ )
58+
59+ if needs_rescale :
60+ # Rescale inputs to 32 bit
61+ rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
62+ input_nodes , tosa_graph
63+ )
64+
65+ # Prepare add output tensor
5966 broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
6067 add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
6168 else :
62- # output.dtype == ts.DType.INT32
6369 add_output = output
70+ rescaled_inputs = inputs
6471
6572 # Do the INT32 Add
6673 tosa_graph .addOperator (
@@ -73,10 +80,10 @@ def define_node(
7380 None ,
7481 )
7582
76- if output . dtype == ts . DType . INT8 :
83+ if needs_rescale :
7784 # Scale output back to 8 bit
7885 # pyre-ignore
79- tqutils .insert_rescale_op_to_int8 ( tosa_graph , add_output , scale_back , node )
86+ tqutils .rescale_node_back_to_int8 ( node , add_output , scale , tosa_graph )
8087
8188
8289@register_node_visitor
@@ -98,19 +105,11 @@ def define_node(
98105 output : TosaArg ,
99106 is_quant_node : bool ,
100107 ) -> None :
101- # Specification (0.80) states that input and output types
102- # should all be the same
103- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
104-
105- if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
108+ if is_quant_node :
106109 # Call the inherited define_node for handling integers
107110 super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
108111 else :
109112 # FP32 Add lowering
110- assert inputs [0 ].dtype == ts .DType .FP32
111- assert output .dtype == ts .DType .FP32
112-
113- # MI lowering
114113 tosa_graph .addOperator (
115114 TosaOp .Op ().ADD ,
116115 [inputs [0 ].name , inputs [1 ].name ],
0 commit comments