1313
1414from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
1515 get_input_qparams ,
16- get_output_qparams ,
1716)
1817from executorch .backends .arm .operators .node_visitor import (
1918 NodeVisitor ,
2625)
2726from executorch .backends .arm .tosa import TosaSpecification
2827from executorch .backends .arm .tosa .mapping import TosaArg
29- from executorch .backends .arm .tosa .quant_utils import build_rescale
30- from tosa .RoundingMode import RoundingMode # type: ignore
3128
3229
3330@register_node_visitor
34- class BMMVisitor (NodeVisitor ):
35- """Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``.
31+ class MatmulVisitor (NodeVisitor ):
32+ """Provide a visitor that serializes TOSA ``MATMUL``."""
3633
37- INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND
38- rounding and output zero-point.
39-
40- """
41-
42- target = "aten.bmm.default"
34+ target = "tosa.MATMUL.default"
4335
4436 tosa_specs = [
4537 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
@@ -56,35 +48,36 @@ def define_node(
5648 inputs : List [TosaArg ],
5749 output : TosaArg ,
5850 ) -> None :
59- """Define the TOSA ``MATMUL`` operator and optional rescale ."""
51+ """Define the TOSA ``MATMUL`` operator."""
6052 import serializer .tosa_serializer as ts # type: ignore
6153
6254 validate_num_inputs (self .target , inputs , 2 )
63- validate_same_dtype (self .target , [* inputs , output ], ts )
55+ validate_same_dtype (self .target , [* inputs ], ts )
6456 validate_valid_dtype (
6557 self .target ,
66- [* inputs , output ],
67- [ts .DType .INT8 , ts .DType .INT16 , ts .DType .FP32 ],
58+ [* inputs ],
59+ [ts .DType .INT8 , ts .DType .FP32 ],
60+ output .tosa_spec ,
61+ )
62+ validate_valid_dtype (
63+ self .target ,
64+ [output ],
65+ [ts .DType .INT32 , ts .DType .FP32 ],
6866 output .tosa_spec ,
6967 )
7068
71- # aten.bmm maps directly to MATMUL
72-
73- # For INT8, we need to get the zero points and add an intermediate tensor
74- # for a later rescale.
75-
69+ # We need to get the zero points and add an intermediate tensor
7670 if inputs [0 ].dtype == ts .DType .INT8 :
7771 input_qparams = get_input_qparams (node )
7872 input0_zp = input_qparams [0 ].get_zp_per_tensor ()
7973 input1_zp = input_qparams [1 ].get_zp_per_tensor ()
80- bmm_result = tosa_graph .addIntermediate (output .shape , ts .DType .INT32 )
81- bmm_output_name = bmm_result .name
8274 else :
83- bmm_output_name = output .name
8475 input0_zp , input1_zp = 0 , 0
8576
86- tosa_graph .addConst ([1 ], inputs [0 ].dtype , [input0_zp ], name = f"{ node .name } _A_ZP" )
87- tosa_graph .addConst ([1 ], inputs [1 ].dtype , [input1_zp ], name = f"{ node .name } _B_ZP" )
77+ input_A_ZP_name = f"{ node .name } _A_ZP"
78+ input_B_ZP_name = f"{ node .name } _B_ZP"
79+ tosa_graph .addConst ([1 ], inputs [0 ].dtype , [input0_zp ], name = input_A_ZP_name )
80+ tosa_graph .addConst ([1 ], inputs [1 ].dtype , [input1_zp ], name = input_B_ZP_name )
8881
8982 # Add the MATMUL to the TOSA graph.
9083 self ._serialize_operator (
@@ -94,27 +87,8 @@ def define_node(
9487 [
9588 inputs [0 ].name ,
9689 inputs [1 ].name ,
97- f" { node . name } _A_ZP" ,
98- f" { node . name } _B_ZP" ,
90+ input_A_ZP_name ,
91+ input_B_ZP_name ,
9992 ],
100- [bmm_output_name ],
93+ [output . name ],
10194 )
102-
103- # As INT8 accumulates into INT32, we need to rescale it back to INT8
104- if output .dtype == ts .DType .INT8 :
105- output_qparams = get_output_qparams (node )[0 ]
106- final_output_scale = (
107- input_qparams [0 ].get_scale_per_tensor () * input_qparams [1 ].get_scale_per_tensor () # type: ignore[possibly-undefined] # pyre-ignore[61]
108- ) / output_qparams .get_scale_per_tensor ()
109-
110- build_rescale (
111- tosa_fb = tosa_graph ,
112- scale = [final_output_scale ],
113- # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
114- input_node = bmm_result , # type: ignore[possibly-undefined]
115- output_name = output .name ,
116- output_type = ts .DType .INT8 ,
117- input_zp = [0 ],
118- output_zp = [output_qparams .get_zp_per_tensor ()],
119- rounding_mode = RoundingMode .SINGLE_ROUND ,
120- )
0 commit comments