55
66# pyre-unsafe
77
8- from typing import cast , List
8+ from typing import List
99
1010import executorch .backends .arm .tosa_quant_utils as tqutils
1111import executorch .backends .arm .tosa_utils as tutils
1212
1313import serializer .tosa_serializer as ts
1414import torch
15+ from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
16+ get_input_qparams ,
17+ )
1518
1619from executorch .backends .arm .operators .node_visitor import (
1720 NodeVisitor ,
1821 register_node_visitor ,
1922)
2023from executorch .backends .arm .tosa_mapping import TosaArg
24+ from executorch .backends .arm .tosa_specification import TosaSpecification
2125from serializer .tosa_serializer import TosaOp
2226
2327
2428@register_node_visitor
25- class MulVisitor (NodeVisitor ):
29+ class MulVisitor_080_BI (NodeVisitor ):
2630 target = "aten.mul.Tensor"
2731
32+ tosa_specs = [
33+ TosaSpecification .create_from_string ("TOSA-0.80.0+BI" ),
34+ ]
35+
2836 def define_node (
2937 self ,
3038 node : torch .fx .Node ,
@@ -33,57 +41,68 @@ def define_node(
3341 output : TosaArg ,
3442 is_quant_node : bool ,
3543 ) -> None :
44+ assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype == ts .DType .INT8
45+ input_A = inputs [0 ]
46+ input_B = inputs [1 ]
47+ input_qparams = get_input_qparams (node )
48+ input_A_qargs = input_qparams [0 ]
49+ input_B_qargs = input_qparams [1 ]
50+ input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
51+ input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
52+
53+ # Rescale inputs to INT32 with zp=0
54+ input_A_rescaled = tqutils .build_rescale_to_int32 (
55+ tosa_graph ,
56+ input_A ,
57+ input_A_qargs .zp ,
58+ rescale_scale = 1.0 ,
59+ )
60+ input_B_rescaled = tqutils .build_rescale_to_int32 (
61+ tosa_graph ,
62+ input_B ,
63+ input_B_qargs .zp ,
64+ rescale_scale = 1.0 ,
65+ )
66+
67+ output_shape = tutils .tosa_shape (output .shape , output .dim_order )
68+ mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
69+
70+ # Do the INT32 Mul
71+ attr = ts .TosaSerializerAttribute ()
72+ attr .MulAttribute (shift = 0 )
73+ tosa_graph .addOperator (
74+ TosaOp .Op ().MUL ,
75+ [
76+ input_A_rescaled .name ,
77+ input_B_rescaled .name ,
78+ ],
79+ [mul_output .name ],
80+ attr ,
81+ )
82+ output_scale = input_A_qargs .scale * input_B_qargs .scale
83+ tqutils .insert_rescale_op_to_int8 (tosa_graph , mul_output , output_scale , node )
84+
85+
86+ @register_node_visitor
87+ class MulVisitor_080_MI (MulVisitor_080_BI ):
88+ # inheriting 'target' from BI class
89+
90+ tosa_specs = [
91+ TosaSpecification .create_from_string ("TOSA-0.80.0+MI" ),
92+ ]
3693
37- if is_quant_node :
38- input_A = inputs [0 ]
39- input_B = inputs [1 ]
40- input_A_qargs = tqutils .get_quant_arg_upstream (
41- cast (torch .fx .Node , node .args [0 ])
42- )
43- input_B_qargs = tqutils .get_quant_arg_upstream (
44- cast (torch .fx .Node , node .args [1 ])
45- )
46-
47- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
48- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
49- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
50-
51- # Rescale inputs to INT32 with zp=0
52- input_A_rescaled = tqutils .build_rescale_to_int32 (
53- tosa_graph ,
54- input_A ,
55- input_A_qargs .zp ,
56- rescale_scale = 1.0 ,
57- )
58- input_B_rescaled = tqutils .build_rescale_to_int32 (
59- tosa_graph ,
60- input_B ,
61- input_B_qargs .zp ,
62- rescale_scale = 1.0 ,
63- )
64-
65- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
66-
67- # Do the INT32 Mul
68- attr = ts .TosaSerializerAttribute ()
69- attr .MulAttribute (shift = 0 )
70- tosa_graph .addOperator (
71- TosaOp .Op ().MUL ,
72- [
73- input_A_rescaled .name ,
74- input_B_rescaled .name ,
75- ],
76- [mul_output .name ],
77- attr ,
78- )
79-
80- tqutils .rescale_node_back_to_int8 (
81- node , mul_output , input_A_qargs .scale * input_B_qargs .scale , tosa_graph
82- )
83-
84- else :
85- attr = ts .TosaSerializerAttribute ()
86- attr .MulAttribute (shift = 0 )
87- tosa_graph .addOperator (
88- TosaOp .Op ().MUL , [inputs [0 ].name , inputs [1 ].name ], [output .name ], attr
89- )
94+ def define_node (
95+ self ,
96+ node : torch .fx .Node ,
97+ tosa_graph : ts .TosaSerializer ,
98+ inputs : List [TosaArg ],
99+ output : TosaArg ,
100+ is_quant_node : bool ,
101+ ) -> None :
102+ if inputs [0 ].dtype == ts .DType .INT8 :
103+ return super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
104+ attr = ts .TosaSerializerAttribute ()
105+ attr .MulAttribute (shift = 0 )
106+ tosa_graph .addOperator (
107+ TosaOp .Op ().MUL , [inputs [0 ].name , inputs [1 ].name ], [output .name ], attr
108+ )
0 commit comments