@@ -50,38 +50,50 @@ def define_node(
5050 validate_num_inputs (self .target , inputs , 2 )
5151 validate_same_dtype (self .target , [* inputs , output ], ts )
5252 validate_valid_dtype (
53- self .target , [* inputs , output ], ts .DType .INT8 , output .tosa_spec
53+ self .target ,
54+ [* inputs , output ],
55+ [ts .DType .INT8 , ts .DType .INT32 ],
56+ output .tosa_spec ,
5457 )
5558
5659 dim_order = (
5760 inputs [0 ].dim_order
5861 if len (inputs [0 ].shape ) > len (inputs [1 ].shape )
5962 else inputs [1 ].dim_order
6063 )
61- input_A = inputs [0 ]
62- input_B = inputs [1 ]
63- input_qparams = get_input_qparams (node )
64- input_A_qargs = input_qparams [0 ]
65- input_B_qargs = input_qparams [1 ]
66- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
67- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
68-
69- # Rescale inputs to INT32 with zp=0
70- input_A_rescaled = tqutils .build_rescale_to_int32 (
71- tosa_graph ,
72- input_A ,
73- input_A_qargs .get_zp_per_tensor (),
74- 1.0 ,
75- )
76- input_B_rescaled = tqutils .build_rescale_to_int32 (
77- tosa_graph ,
78- input_B ,
79- input_B_qargs .get_zp_per_tensor (),
80- 1.0 ,
81- )
82-
83- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
84- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
64+ if inputs [0 ].dtype == ts .DType .INT8 :
65+ input_A = inputs [0 ]
66+ input_B = inputs [1 ]
67+ input_qparams = get_input_qparams (node )
68+ input_A_qargs = input_qparams [0 ]
69+ input_B_qargs = input_qparams [1 ]
70+ input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
71+ input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
72+
73+ # Rescale inputs to INT32 with zp=0
74+ input_A_rescaled = tqutils .build_rescale_to_int32 (
75+ tosa_graph ,
76+ input_A ,
77+ input_A_qargs .get_zp_per_tensor (),
78+ 1.0 ,
79+ )
80+ input_B_rescaled = tqutils .build_rescale_to_int32 (
81+ tosa_graph ,
82+ input_B ,
83+ input_B_qargs .get_zp_per_tensor (),
84+ 1.0 ,
85+ )
86+ else :
87+ # inputs[0].dtype == ts.DType.INT32
88+ # Non-quantized input, natively supported by TOSA.MUL
89+ input_A_rescaled , input_B_rescaled = inputs [0 ], inputs [1 ]
90+
91+ if output .dtype == ts .DType .INT8 :
92+ output_shape = tutils .tosa_shape (output .shape , output .dim_order )
93+ mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
94+ else :
95+ # output.dtype == ts.DType.INT32
96+ mul_output = output
8597
8698 input1 , input2 = tutils .reshape_for_broadcast (
8799 tosa_graph ,
@@ -101,10 +113,16 @@ def define_node(
101113 [mul_output .name ],
102114 attr ,
103115 )
104- output_scale = (
105- input_A_qargs .get_scale_per_tensor () * input_B_qargs .get_scale_per_tensor ()
106- )
107- tqutils .insert_rescale_op_to_int8 (tosa_graph , mul_output , output_scale , node )
116+
117+ if output .dtype == ts .DType .INT8 :
118+ # Scale output back to 8 bit
119+ output_scale = (
120+ input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
121+ * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
122+ )
123+ tqutils .insert_rescale_op_to_int8 (
124+ tosa_graph , mul_output , output_scale , node
125+ )
108126
109127
110128@register_node_visitor
@@ -161,35 +179,47 @@ def define_node(
161179 validate_num_inputs (self .target , inputs , 2 )
162180 validate_same_dtype (self .target , [* inputs , output ], ts )
163181 validate_valid_dtype (
164- self .target , [* inputs , output ], ts .DType .INT8 , output .tosa_spec
165- )
166-
167- input_A = inputs [0 ]
168- input_B = inputs [1 ]
169- input_qparams = get_input_qparams (node )
170- input_A_qargs = input_qparams [0 ]
171- input_B_qargs = input_qparams [1 ]
172- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
173- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
174-
175- # Rescale inputs to INT32 with zp=0
176- input_A_rescaled = tqutils .build_rescale_to_int32 (
177- tosa_graph ,
178- input_A ,
179- input_A_qargs .get_zp_per_tensor (),
180- 1.0 ,
181- tosa_spec = self .tosa_spec ,
182- )
183- input_B_rescaled = tqutils .build_rescale_to_int32 (
184- tosa_graph ,
185- input_B ,
186- input_B_qargs .get_zp_per_tensor (),
187- 1.0 ,
188- tosa_spec = self .tosa_spec ,
182+ self .target ,
183+ [* inputs , output ],
184+ [ts .DType .INT8 , ts .DType .INT32 ],
185+ output .tosa_spec ,
189186 )
190187
191- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
192- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
188+ if inputs [0 ].dtype == ts .DType .INT8 :
189+ input_A = inputs [0 ]
190+ input_B = inputs [1 ]
191+ input_qparams = get_input_qparams (node )
192+ input_A_qargs = input_qparams [0 ]
193+ input_B_qargs = input_qparams [1 ]
194+ input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
195+ input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
196+
197+ # Rescale inputs to INT32 with zp=0
198+ input_A_rescaled = tqutils .build_rescale_to_int32 (
199+ tosa_graph ,
200+ input_A ,
201+ input_A_qargs .get_zp_per_tensor (),
202+ 1.0 ,
203+ tosa_spec = self .tosa_spec ,
204+ )
205+ input_B_rescaled = tqutils .build_rescale_to_int32 (
206+ tosa_graph ,
207+ input_B ,
208+ input_B_qargs .get_zp_per_tensor (),
209+ 1.0 ,
210+ tosa_spec = self .tosa_spec ,
211+ )
212+ else :
213+ # inputs[0].dtype == ts.DType.INT32
214+ # Non-quantized input, natively supported by TOSA.MUL
215+ input_A_rescaled , input_B_rescaled = inputs [0 ], inputs [1 ]
216+
217+ if output .dtype == ts .DType .INT8 :
218+ output_shape = tutils .tosa_shape (output .shape , output .dim_order )
219+ mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
220+ else :
221+ # output.dtype == ts.DType.INT32
222+ mul_output = output
193223
194224 # Do the INT32 Mul
195225 tosa_graph .addConst ([1 ], ts .DType .INT8 , 0 , name = f"{ node .name } _shift" )
@@ -198,12 +228,16 @@ def define_node(
198228 [input_A_rescaled .name , input_B_rescaled .name , f"{ node .name } _shift" ],
199229 [mul_output .name ],
200230 )
201- output_scale = (
202- input_A_qargs .get_scale_per_tensor () * input_B_qargs .get_scale_per_tensor ()
203- )
204- tqutils .insert_rescale_op_to_int8 (
205- tosa_graph , mul_output , output_scale , node , self .tosa_spec
206- )
231+
232+ if output .dtype == ts .DType .INT8 :
233+ # Scale output back to 8 bit
234+ output_scale = (
235+ input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
236+ * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
237+ )
238+ tqutils .insert_rescale_op_to_int8 (
239+ tosa_graph , mul_output , output_scale , node , self .tosa_spec
240+ )
207241
208242
209243@register_node_visitor
0 commit comments