Skip to content

Commit 5caf20d

Browse files
YufengShi-dudu and hinriksnaer
authored and committed
Arm backend: Add int32 support to aten.mul.Tensor with BI/INT profile (pytorch#11964)
- The current aten.mul.Tensor node visitor only supports INT8 data type with BI/INT profile. This patch adds int32 support to the mul node visitor with BI/INT profile. - However, tests with int32 inputs that require broadcasting fail on u55 and u85. Signed-off-by: Yufeng Shi <[email protected]>
1 parent bafd951 commit 5caf20d

File tree

2 files changed

+180
-62
lines changed

2 files changed

+180
-62
lines changed

backends/arm/operators/op_mul.py

Lines changed: 96 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -50,38 +50,50 @@ def define_node(
5050
validate_num_inputs(self.target, inputs, 2)
5151
validate_same_dtype(self.target, [*inputs, output], ts)
5252
validate_valid_dtype(
53-
self.target, [*inputs, output], ts.DType.INT8, output.tosa_spec
53+
self.target,
54+
[*inputs, output],
55+
[ts.DType.INT8, ts.DType.INT32],
56+
output.tosa_spec,
5457
)
5558

5659
dim_order = (
5760
inputs[0].dim_order
5861
if len(inputs[0].shape) > len(inputs[1].shape)
5962
else inputs[1].dim_order
6063
)
61-
input_A = inputs[0]
62-
input_B = inputs[1]
63-
input_qparams = get_input_qparams(node)
64-
input_A_qargs = input_qparams[0]
65-
input_B_qargs = input_qparams[1]
66-
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
67-
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
68-
69-
# Rescale inputs to INT32 with zp=0
70-
input_A_rescaled = tqutils.build_rescale_to_int32(
71-
tosa_graph,
72-
input_A,
73-
input_A_qargs.get_zp_per_tensor(),
74-
1.0,
75-
)
76-
input_B_rescaled = tqutils.build_rescale_to_int32(
77-
tosa_graph,
78-
input_B,
79-
input_B_qargs.get_zp_per_tensor(),
80-
1.0,
81-
)
82-
83-
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
84-
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
64+
if inputs[0].dtype == ts.DType.INT8:
65+
input_A = inputs[0]
66+
input_B = inputs[1]
67+
input_qparams = get_input_qparams(node)
68+
input_A_qargs = input_qparams[0]
69+
input_B_qargs = input_qparams[1]
70+
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
71+
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
72+
73+
# Rescale inputs to INT32 with zp=0
74+
input_A_rescaled = tqutils.build_rescale_to_int32(
75+
tosa_graph,
76+
input_A,
77+
input_A_qargs.get_zp_per_tensor(),
78+
1.0,
79+
)
80+
input_B_rescaled = tqutils.build_rescale_to_int32(
81+
tosa_graph,
82+
input_B,
83+
input_B_qargs.get_zp_per_tensor(),
84+
1.0,
85+
)
86+
else:
87+
# input[0].dtype == ts.DType.INT32
88+
# Non quantized input, natively support by TOSA.MUL
89+
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
90+
91+
if output.dtype == ts.DType.INT8:
92+
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
93+
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
94+
else:
95+
# output.dtype == ts.DType.INT32
96+
mul_output = output
8597

8698
input1, input2 = tutils.reshape_for_broadcast(
8799
tosa_graph,
@@ -101,10 +113,16 @@ def define_node(
101113
[mul_output.name],
102114
attr,
103115
)
104-
output_scale = (
105-
input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor()
106-
)
107-
tqutils.insert_rescale_op_to_int8(tosa_graph, mul_output, output_scale, node)
116+
117+
if output.dtype == ts.DType.INT8:
118+
# Scale output back to 8 bit
119+
output_scale = (
120+
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
121+
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
122+
)
123+
tqutils.insert_rescale_op_to_int8(
124+
tosa_graph, mul_output, output_scale, node
125+
)
108126

109127

110128
@register_node_visitor
@@ -161,35 +179,47 @@ def define_node(
161179
validate_num_inputs(self.target, inputs, 2)
162180
validate_same_dtype(self.target, [*inputs, output], ts)
163181
validate_valid_dtype(
164-
self.target, [*inputs, output], ts.DType.INT8, output.tosa_spec
165-
)
166-
167-
input_A = inputs[0]
168-
input_B = inputs[1]
169-
input_qparams = get_input_qparams(node)
170-
input_A_qargs = input_qparams[0]
171-
input_B_qargs = input_qparams[1]
172-
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
173-
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
174-
175-
# Rescale inputs to INT32 with zp=0
176-
input_A_rescaled = tqutils.build_rescale_to_int32(
177-
tosa_graph,
178-
input_A,
179-
input_A_qargs.get_zp_per_tensor(),
180-
1.0,
181-
tosa_spec=self.tosa_spec,
182-
)
183-
input_B_rescaled = tqutils.build_rescale_to_int32(
184-
tosa_graph,
185-
input_B,
186-
input_B_qargs.get_zp_per_tensor(),
187-
1.0,
188-
tosa_spec=self.tosa_spec,
182+
self.target,
183+
[*inputs, output],
184+
[ts.DType.INT8, ts.DType.INT32],
185+
output.tosa_spec,
189186
)
190187

191-
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
192-
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
188+
if inputs[0].dtype == ts.DType.INT8:
189+
input_A = inputs[0]
190+
input_B = inputs[1]
191+
input_qparams = get_input_qparams(node)
192+
input_A_qargs = input_qparams[0]
193+
input_B_qargs = input_qparams[1]
194+
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
195+
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
196+
197+
# Rescale inputs to INT32 with zp=0
198+
input_A_rescaled = tqutils.build_rescale_to_int32(
199+
tosa_graph,
200+
input_A,
201+
input_A_qargs.get_zp_per_tensor(),
202+
1.0,
203+
tosa_spec=self.tosa_spec,
204+
)
205+
input_B_rescaled = tqutils.build_rescale_to_int32(
206+
tosa_graph,
207+
input_B,
208+
input_B_qargs.get_zp_per_tensor(),
209+
1.0,
210+
tosa_spec=self.tosa_spec,
211+
)
212+
else:
213+
# input[0].dtype == ts.DType.INT32
214+
# Non quantized input, natively support by TOSA.MUL
215+
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
216+
217+
if output.dtype == ts.DType.INT8:
218+
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
219+
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
220+
else:
221+
# output.dtype == ts.DType.INT32
222+
mul_output = output
193223

194224
# Do the INT32 Mul
195225
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
@@ -198,12 +228,16 @@ def define_node(
198228
[input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"],
199229
[mul_output.name],
200230
)
201-
output_scale = (
202-
input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor()
203-
)
204-
tqutils.insert_rescale_op_to_int8(
205-
tosa_graph, mul_output, output_scale, node, self.tosa_spec
206-
)
231+
232+
if output.dtype == ts.DType.INT8:
233+
# Scale output back to 8 bit
234+
output_scale = (
235+
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
236+
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
237+
)
238+
tqutils.insert_rescale_op_to_int8(
239+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
240+
)
207241

208242

209243
@register_node_visitor

backends/arm/test/ops/test_mul.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,23 @@
7979
}
8080

8181

82+
test_data_suite_int32 = {
83+
# (test_name, input, other,) See torch.mul() for info
84+
"op_mul_rank4_randn_int32": lambda: (
85+
torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32),
86+
torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32),
87+
),
88+
"op_mul_rank4_randn_mutltiple_broadcasts_int32": lambda: (
89+
torch.randint(0, 10, (1, 4, 4, 1), dtype=torch.int32),
90+
torch.randint(0, 10, (1, 1, 4, 4), dtype=torch.int32),
91+
),
92+
"op_mul_rank4_randn_broadcast_int32": lambda: (
93+
torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32),
94+
torch.randint(0, 10, (1, 25, 20), dtype=torch.int32),
95+
),
96+
}
97+
98+
8299
class Mul(torch.nn.Module):
83100

84101
def forward(
@@ -111,6 +128,17 @@ def test_mul_tensor_tosa_MI_diff_input_ranks(test_data: torch.Tensor):
111128
pipeline.run()
112129

113130

131+
@common.parametrize("test_data", test_data_suite_int32)
132+
def test_mul_tensor_tosa_MI_int32(test_data: torch.Tensor):
133+
pipeline = TosaPipelineMI[input_t1](
134+
Mul(),
135+
test_data(),
136+
aten_op,
137+
exir_op=[],
138+
)
139+
pipeline.run()
140+
141+
114142
@common.parametrize("test_data", test_data_suite_2)
115143
def test_mul_tensor_tosa_BI_diff_input_ranks(test_data: torch.Tensor):
116144
pipeline = TosaPipelineBI[input_t1](
@@ -133,6 +161,18 @@ def test_mul_tensor_tosa_BI(test_data: torch.Tensor):
133161
pipeline.run()
134162

135163

164+
@common.parametrize("test_data", test_data_suite_int32)
165+
def test_mul_tensor_tosa_BI_int32(test_data: torch.Tensor):
166+
pipeline = TosaPipelineBI[input_t1](
167+
Mul(),
168+
test_data(),
169+
aten_op,
170+
exir_op=[],
171+
)
172+
pipeline.pop_stage("check.quant_nodes")
173+
pipeline.run()
174+
175+
136176
@common.parametrize("test_data", test_data_suite)
137177
@common.XfailIfNoCorstone300
138178
def test_mul_tensor_u55_BI(test_data: torch.Tensor):
@@ -157,3 +197,47 @@ def test_mul_tensor_u85_BI(test_data: torch.Tensor):
157197
run_on_fvp=True,
158198
)
159199
pipeline.run()
200+
201+
202+
@common.parametrize(
203+
"test_data",
204+
test_data_suite_int32,
205+
xfails={
206+
# TODO: MLETORCH-1132 Investigate why tests with inputs that require broadcasting fail on u55/u85
207+
"op_mul_rank4_randn_mutltiple_broadcasts_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int",
208+
"op_mul_rank4_randn_broadcast_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int",
209+
},
210+
)
211+
@common.XfailIfNoCorstone300
212+
def test_mul_tensor_u55_BI_int32(test_data: torch.Tensor):
213+
pipeline = EthosU55PipelineBI[input_t1](
214+
Mul(),
215+
test_data(),
216+
aten_op,
217+
exir_ops=[],
218+
run_on_fvp=True,
219+
)
220+
pipeline.pop_stage("check.quant_nodes")
221+
pipeline.run()
222+
223+
224+
@common.parametrize(
225+
"test_data",
226+
test_data_suite_int32,
227+
xfails={
228+
# TODO: MLETORCH-1132 Investigate why tests with inputs that require broadcasting fail on u55/u85
229+
"op_mul_rank4_randn_mutltiple_broadcasts_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int",
230+
"op_mul_rank4_randn_broadcast_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int",
231+
},
232+
)
233+
@common.XfailIfNoCorstone320
234+
def test_mul_tensor_u85_BI_int32(test_data: torch.Tensor):
235+
pipeline = EthosU85PipelineBI[input_t1](
236+
Mul(),
237+
test_data(),
238+
aten_op,
239+
exir_ops=[],
240+
run_on_fvp=True,
241+
)
242+
pipeline.pop_stage("check.quant_nodes")
243+
pipeline.run()

0 commit comments

Comments
 (0)