
Commit 6358b0a

Arm backend: Add sub tensor to match scalar tensors to input data types.
Change-Id: I14fe181345f5c6034c0e230329425f5ba74f4910
1 parent 07c8f0f commit 6358b0a

File tree: 5 files changed, +19 −25 lines
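For context (not part of the commit): in eager PyTorch, sub between a tensor and a Python scalar is type-promoted automatically. After export the scalar is materialized as a tensor (see replace_scalar_with_tensor_pass in the first file below), and its dtype can differ from the other operand, which TOSA rejects; that is the "All IO needs to have the same data type" error removed from the xfails in test_scalars.py further down. A quick eager-mode illustration of the promotion behaviour the backend has to reproduce:

# Eager-mode reference only: PyTorch's own promotion rules for tensor/scalar sub.
import torch

x = torch.rand(1, 6, 2)                      # float32 tensor
print(torch.sub(x, 1).dtype)                 # torch.float32, the int scalar is promoted

y = torch.ones(1, 6, 2, dtype=torch.int32)   # int32 tensor
print(torch.sub(y, 1.5).dtype)               # torch.float32, the int tensor is promoted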

backends/arm/_passes/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -67,8 +67,8 @@
 )
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
+from .match_arg_dtype_pass import MatchArgDtypePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
-from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
 from .replace_scalar_with_tensor_pass import (  # noqa

backends/arm/_passes/arm_pass_manager.py
Lines changed: 3 additions & 4 deletions

@@ -66,8 +66,8 @@
     InsertCastForOpsWithInt64InputPass,
     InsertRescalePass,
     InsertTableOpsPass,
+    MatchArgDtypePass,
     MatchArgRanksPass,
-    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceInfValues,
@@ -116,7 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
+        self.add_pass(MatchArgDtypePass())
         if self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -193,8 +193,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
-
+        self.add_pass(MatchArgDtypePass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]

backends/arm/_passes/match_where_self_arg_dtype_pass.py renamed to backends/arm/_passes/match_arg_dtype_pass.py
Lines changed: 11 additions & 7 deletions

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -26,7 +26,7 @@ def get_largest_dtype(dtype_1, dtype_2):
     return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2
 
 
-class MatchWhereSelfDtypePass(ExportPass):
+class MatchArgDtypePass(ExportPass):
     """Pass to match data types of non-condition input tensors.
 
     Edge dialect allows different data types for non-condition tensors, while TOSA
@@ -38,14 +38,18 @@ class MatchWhereSelfDtypePass(ExportPass):
 
     """
 
+    targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self}
+
     def call(self, graph_module: torch.fx.GraphModule):
         modified_graph = False
         graph = graph_module.graph
-        node_list = graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.where.self
-        )
-        for node in node_list:
-            cond, input_, other_ = node.args
+
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            input_ = get_node_arg(node.args, 0)
+            other_ = get_node_arg(node.args, 1)
 
             input_dtype = input_.meta["val"].dtype
             other_dtype = other_.meta["val"].dtype
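The hunk above only shows how mismatched arguments are detected; the cast that aligns them is presumably inserted further down in the file and is not part of this diff. For readers who want the idea end to end, below is a minimal sketch of the same technique on plain torch.fx. The rank table, the torch.sub target set and the Tensor.to() cast are illustrative assumptions, not the ExecuTorch implementation.

# Illustrative sketch only (plain torch.fx, not the ExecuTorch pass or the exir
# dialect): insert a cast so both operands of a targeted op share one dtype.
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp

# Assumed promotion order; the real pass defines its own DTYPE_RANK table.
DTYPE_RANK = {torch.bool: 0, torch.int8: 1, torch.int32: 2, torch.int64: 3,
              torch.float16: 4, torch.float32: 5}


def get_largest_dtype(dtype_1, dtype_2):
    # Same promotion rule as in the pass above: keep the higher-ranked dtype.
    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


class SubModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y)


gm = fx.symbolic_trace(SubModule())
example = (torch.rand(2, 3), torch.ones(2, 3, dtype=torch.int32))
ShapeProp(gm).propagate(*example)  # fills node.meta["tensor_meta"] with dtypes

targeted_ops = {torch.sub}
for node in list(gm.graph.nodes):
    if node.op != "call_function" or node.target not in targeted_ops:
        continue
    input_, other_ = node.args[0], node.args[1]
    input_dtype = input_.meta["tensor_meta"].dtype
    other_dtype = other_.meta["tensor_meta"].dtype
    if input_dtype == other_dtype:
        continue
    target_dtype = get_largest_dtype(input_dtype, other_dtype)
    smaller = input_ if input_dtype != target_dtype else other_
    # Cast the lower-ranked operand right before the op and rewire the argument.
    with gm.graph.inserting_before(node):
        cast = gm.graph.call_method("to", (smaller,), {"dtype": target_dtype})
    node.replace_input_with(smaller, cast)

gm.recompile()
print(gm.code)             # sub now consumes the cast of the int32 operand
print(gm(*example).dtype)  # torch.float32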

backends/arm/test/ops/test_scalars.py
Lines changed: 2 additions & 7 deletions

@@ -242,21 +242,16 @@ def test_add_scalar_u85_BI():
 
 
 # SUB MI ------------------------------------------------------
-mi_sub_xfails = {
-    "int_r1_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    "int_r4_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    **xfails,
-}
 
 
-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_scalar(test_data):
     """Tests regular sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](Sub(), test_data, aten_op=Sub.aten_op)
     pipeline.run()
 
 
-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_inplace(test_data):
     """Tests inplace sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](SubInplace(), test_data, aten_op=[])

backends/arm/test/ops/test_sub.py
Lines changed: 2 additions & 6 deletions

@@ -42,6 +42,8 @@
         torch.randn(1, 4, 4, 1),
         torch.randn(1, 1, 4, 4),
     ),
+    "rand_3d_rand_Scalar": lambda: (torch.rand(1, 6, 2), torch.rand(1)),
+    "rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
 }
 fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}
 
@@ -93,7 +95,6 @@ def test_sub_tensor_tosa_BI(test_data):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
 
 
@@ -106,7 +107,6 @@ def test_sub_tensor_tosa_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
 
 
@@ -121,7 +121,6 @@ def test_sub_tensor_u55_BI(test_data):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
 
 
@@ -136,7 +135,6 @@ def test_sub_tensor_u55_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
 
 
@@ -151,7 +149,6 @@ def test_sub_tensor_u85_BI_2(test_data):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
 
 
@@ -166,5 +163,4 @@ def test_sub_tensor_u85_BI(test_data: Tuple[torch.Tensor, torch.Tensor]):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
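As a side note (not part of the commit): the two new entries can be sanity-checked in eager PyTorch, where the rank-1 tensor broadcasts over the (1, 6, 2) input and the Python int scalar is promoted to float32.

# Eager-mode sanity check of the two new test inputs; not part of the test suite.
import torch

new_cases = {
    "rand_3d_rand_Scalar": lambda: (torch.rand(1, 6, 2), torch.rand(1)),
    "rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
}

for name, make_inputs in new_cases.items():
    x, y = make_inputs()
    out = torch.sub(x, y)  # rank-1 tensor broadcasts; the int scalar is promoted
    print(name, tuple(out.shape), out.dtype)  # (1, 6, 2) torch.float32 for both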
