@@ -50,38 +50,50 @@ def define_node(
50
50
validate_num_inputs (self .target , inputs , 2 )
51
51
validate_same_dtype (self .target , [* inputs , output ], ts )
52
52
validate_valid_dtype (
53
- self .target , [* inputs , output ], ts .DType .INT8 , output .tosa_spec
53
+ self .target ,
54
+ [* inputs , output ],
55
+ [ts .DType .INT8 , ts .DType .INT32 ],
56
+ output .tosa_spec ,
54
57
)
55
58
56
59
dim_order = (
57
60
inputs [0 ].dim_order
58
61
if len (inputs [0 ].shape ) > len (inputs [1 ].shape )
59
62
else inputs [1 ].dim_order
60
63
)
61
- input_A = inputs [0 ]
62
- input_B = inputs [1 ]
63
- input_qparams = get_input_qparams (node )
64
- input_A_qargs = input_qparams [0 ]
65
- input_B_qargs = input_qparams [1 ]
66
- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
67
- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
68
-
69
- # Rescale inputs to INT32 with zp=0
70
- input_A_rescaled = tqutils .build_rescale_to_int32 (
71
- tosa_graph ,
72
- input_A ,
73
- input_A_qargs .get_zp_per_tensor (),
74
- 1.0 ,
75
- )
76
- input_B_rescaled = tqutils .build_rescale_to_int32 (
77
- tosa_graph ,
78
- input_B ,
79
- input_B_qargs .get_zp_per_tensor (),
80
- 1.0 ,
81
- )
82
-
83
- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
84
- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
64
+ if inputs [0 ].dtype == ts .DType .INT8 :
65
+ input_A = inputs [0 ]
66
+ input_B = inputs [1 ]
67
+ input_qparams = get_input_qparams (node )
68
+ input_A_qargs = input_qparams [0 ]
69
+ input_B_qargs = input_qparams [1 ]
70
+ input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
71
+ input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
72
+
73
+ # Rescale inputs to INT32 with zp=0
74
+ input_A_rescaled = tqutils .build_rescale_to_int32 (
75
+ tosa_graph ,
76
+ input_A ,
77
+ input_A_qargs .get_zp_per_tensor (),
78
+ 1.0 ,
79
+ )
80
+ input_B_rescaled = tqutils .build_rescale_to_int32 (
81
+ tosa_graph ,
82
+ input_B ,
83
+ input_B_qargs .get_zp_per_tensor (),
84
+ 1.0 ,
85
+ )
86
+ else :
87
+ # input[0].dtype == ts.DType.INT32
88
+ # Non quantized input, natively support by TOSA.MUL
89
+ input_A_rescaled , input_B_rescaled = inputs [0 ], inputs [1 ]
90
+
91
+ if output .dtype == ts .DType .INT8 :
92
+ output_shape = tutils .tosa_shape (output .shape , output .dim_order )
93
+ mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
94
+ else :
95
+ # output.dtype == ts.DType.INT32
96
+ mul_output = output
85
97
86
98
input1 , input2 = tutils .reshape_for_broadcast (
87
99
tosa_graph ,
@@ -101,10 +113,16 @@ def define_node(
101
113
[mul_output .name ],
102
114
attr ,
103
115
)
104
- output_scale = (
105
- input_A_qargs .get_scale_per_tensor () * input_B_qargs .get_scale_per_tensor ()
106
- )
107
- tqutils .insert_rescale_op_to_int8 (tosa_graph , mul_output , output_scale , node )
116
+
117
+ if output .dtype == ts .DType .INT8 :
118
+ # Scale output back to 8 bit
119
+ output_scale = (
120
+ input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
121
+ * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
122
+ )
123
+ tqutils .insert_rescale_op_to_int8 (
124
+ tosa_graph , mul_output , output_scale , node
125
+ )
108
126
109
127
110
128
@register_node_visitor
@@ -161,35 +179,47 @@ def define_node(
161
179
validate_num_inputs (self .target , inputs , 2 )
162
180
validate_same_dtype (self .target , [* inputs , output ], ts )
163
181
validate_valid_dtype (
164
- self .target , [* inputs , output ], ts .DType .INT8 , output .tosa_spec
165
- )
166
-
167
- input_A = inputs [0 ]
168
- input_B = inputs [1 ]
169
- input_qparams = get_input_qparams (node )
170
- input_A_qargs = input_qparams [0 ]
171
- input_B_qargs = input_qparams [1 ]
172
- input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
173
- input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
174
-
175
- # Rescale inputs to INT32 with zp=0
176
- input_A_rescaled = tqutils .build_rescale_to_int32 (
177
- tosa_graph ,
178
- input_A ,
179
- input_A_qargs .get_zp_per_tensor (),
180
- 1.0 ,
181
- tosa_spec = self .tosa_spec ,
182
- )
183
- input_B_rescaled = tqutils .build_rescale_to_int32 (
184
- tosa_graph ,
185
- input_B ,
186
- input_B_qargs .get_zp_per_tensor (),
187
- 1.0 ,
188
- tosa_spec = self .tosa_spec ,
182
+ self .target ,
183
+ [* inputs , output ],
184
+ [ts .DType .INT8 , ts .DType .INT32 ],
185
+ output .tosa_spec ,
189
186
)
190
187
191
- output_shape = tutils .tosa_shape (output .shape , output .dim_order )
192
- mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
188
+ if inputs [0 ].dtype == ts .DType .INT8 :
189
+ input_A = inputs [0 ]
190
+ input_B = inputs [1 ]
191
+ input_qparams = get_input_qparams (node )
192
+ input_A_qargs = input_qparams [0 ]
193
+ input_B_qargs = input_qparams [1 ]
194
+ input_A .shape = tutils .tosa_shape (input_A .shape , input_A .dim_order )
195
+ input_B .shape = tutils .tosa_shape (input_B .shape , input_B .dim_order )
196
+
197
+ # Rescale inputs to INT32 with zp=0
198
+ input_A_rescaled = tqutils .build_rescale_to_int32 (
199
+ tosa_graph ,
200
+ input_A ,
201
+ input_A_qargs .get_zp_per_tensor (),
202
+ 1.0 ,
203
+ tosa_spec = self .tosa_spec ,
204
+ )
205
+ input_B_rescaled = tqutils .build_rescale_to_int32 (
206
+ tosa_graph ,
207
+ input_B ,
208
+ input_B_qargs .get_zp_per_tensor (),
209
+ 1.0 ,
210
+ tosa_spec = self .tosa_spec ,
211
+ )
212
+ else :
213
+ # input[0].dtype == ts.DType.INT32
214
+ # Non quantized input, natively support by TOSA.MUL
215
+ input_A_rescaled , input_B_rescaled = inputs [0 ], inputs [1 ]
216
+
217
+ if output .dtype == ts .DType .INT8 :
218
+ output_shape = tutils .tosa_shape (output .shape , output .dim_order )
219
+ mul_output = tosa_graph .addIntermediate (output_shape , ts .DType .INT32 )
220
+ else :
221
+ # output.dtype == ts.DType.INT32
222
+ mul_output = output
193
223
194
224
# Do the INT32 Mul
195
225
tosa_graph .addConst ([1 ], ts .DType .INT8 , 0 , name = f"{ node .name } _shift" )
@@ -198,12 +228,16 @@ def define_node(
198
228
[input_A_rescaled .name , input_B_rescaled .name , f"{ node .name } _shift" ],
199
229
[mul_output .name ],
200
230
)
201
- output_scale = (
202
- input_A_qargs .get_scale_per_tensor () * input_B_qargs .get_scale_per_tensor ()
203
- )
204
- tqutils .insert_rescale_op_to_int8 (
205
- tosa_graph , mul_output , output_scale , node , self .tosa_spec
206
- )
231
+
232
+ if output .dtype == ts .DType .INT8 :
233
+ # Scale output back to 8 bit
234
+ output_scale = (
235
+ input_A_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
236
+ * input_B_qargs .get_scale_per_tensor () # type: ignore[possibly-undefined]
237
+ )
238
+ tqutils .insert_rescale_op_to_int8 (
239
+ tosa_graph , mul_output , output_scale , node , self .tosa_spec
240
+ )
207
241
208
242
209
243
@register_node_visitor
0 commit comments