77
88from typing import Any , List
99
10- import executorch .backends .arm .tosa .quant_utils as tqutils
11- import executorch .backends .arm .tosa .utils as tutils
1210import torch
1311
14- from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
15- get_input_qparams ,
16- )
17-
1812from executorch .backends .arm .operators .node_visitor import (
1913 NodeVisitor ,
2014 register_node_visitor ,
2418 validate_same_dtype ,
2519 validate_valid_dtype ,
2620)
27- from executorch .backends .arm .tosa import TosaSpecification
2821from executorch .backends .arm .tosa .mapping import TosaArg
22+ from executorch .backends .arm .tosa .specification import TosaSpecification
2923
3024
3125@register_node_visitor
32- class MulVisitor_INT (NodeVisitor ):
26+ class MulVisitor (NodeVisitor ):
3327 target = "aten.mul.Tensor"
3428
3529 tosa_specs = [
30+ TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
3631 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
37- TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
3832 ]
3933
4034 def define_node (
@@ -52,105 +46,13 @@ def define_node(
5246 validate_valid_dtype (
5347 self .target ,
5448 [* inputs , output ],
55- [ts .DType .INT8 , ts .DType .INT16 , ts . DType . INT32 ],
49+ [ts .DType .INT32 , ts .DType .FP32 ],
5650 output .tosa_spec ,
5751 )
5852
59- if inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .INT16 :
60- input_A = inputs [0 ]
61- input_B = inputs [1 ]
62- input_qparams = get_input_qparams (node )
63- input_A_qargs = input_qparams [0 ]
64- input_B_qargs = input_qparams [1 ]
65- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
66- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
67-
68- # Rescale inputs to INT32 with zp=0
69- input_A_rescaled = tqutils .build_rescale_to_int32 (
70- tosa_graph ,
71- input_A ,
72- input_A_qargs .get_zp_per_tensor (),
73- 1.0 ,
74- tosa_spec = self .tosa_spec ,
75- )
76- input_B_rescaled = tqutils .build_rescale_to_int32 (
77- tosa_graph ,
78- input_B ,
79- input_B_qargs .get_zp_per_tensor (),
80- 1.0 ,
81- tosa_spec = self .tosa_spec ,
82- )
83- else :
84- # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
85- # Non quantized input, natively support by TOSA.MUL
86- input_A_rescaled , input_B_rescaled = inputs [0 ], inputs [1 ]
87-
88- if output .dtype == ts .DType .INT8 or output .dtype == ts .DType .INT16 :
89- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
90- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
91- else :
92- # output.dtype == ts.DType.INT32 (non-quantized)
93- mul_output = output
94-
95- # Do the INT32 Mul
96- tosa_graph .addConst ([1 ], ts .DType .INT8 , 0 , name = f"{ node .name } _shift" )
97- self ._serialize_operator (
98- node ,
99- tosa_graph ,
100- ts .TosaOp .Op ().MUL ,
101- [input_A_rescaled .name , input_B_rescaled .name , f"{ node .name } _shift" ],
102- [mul_output .name ],
103- )
104-
105- if output .dtype == ts .DType .INT8 :
106- # Scale output back to 8 bit
107- output_scale = (
108- input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
109- * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
110- )
111- tqutils .insert_rescale_op_to_int8 (
112- tosa_graph , mul_output , output_scale , node , self .tosa_spec
113- )
114- elif output .dtype == ts .DType .INT16 :
115- # Scale output back to 16 bit
116- output_scale = (
117- input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
118- * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
119- )
120- tqutils .insert_rescale_op_to_int16 (
121- tosa_graph , mul_output , output_scale , node , self .tosa_spec
122- )
123-
124-
125- @register_node_visitor
126- class MulVisitor_FP (MulVisitor_INT ):
127- # inheriting 'target' from INT class
128-
129- tosa_specs = [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
130-
131- def define_node (
132- self ,
133- node : torch .fx .Node ,
134- tosa_graph : Any ,
135- inputs : List [TosaArg ],
136- output : TosaArg ,
137- ) -> None :
138-
139- import serializer .tosa_serializer as ts # type: ignore
140-
141- validate_num_inputs (self .target , inputs , 2 )
142- validate_same_dtype (self .target , [* inputs , output ], ts )
143-
144- if inputs [0 ].dtype == ts .DType .INT8 :
145- return super ().define_node (node , tosa_graph , inputs , output )
146-
147- input1 , input2 = inputs
148-
14953 tosa_graph .addConst ([1 ], ts .DType .INT8 , 0 , name = f"{ node .name } _shift" )
150- self ._serialize_operator (
151- node ,
152- tosa_graph ,
54+ tosa_graph .addOperator (
15355 ts .TosaOp .Op ().MUL ,
154- [input1 .name , input2 .name , f"{ node .name } _shift" ],
56+ [inputs [ 0 ] .name , inputs [ 1 ] .name , f"{ node .name } _shift" ],
15557 [output .name ],
15658 )
0 commit comments