2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
@@ -67,8 +67,8 @@
 )
 from .insert_rescales_pass import InsertRescalePass # noqa
 from .insert_table_ops import InsertTableOpsPass # noqa
+from .match_arg_dtype_pass import MatchArgDtypePass # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass # noqa
-from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
 from .remove_clone_pass import RemoveClonePass # noqa
 from .replace_scalar_with_tensor_pass import ( # noqa
7 changes: 3 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
@@ -66,8 +66,8 @@
     InsertCastForOpsWithInt64InputPass,
     InsertRescalePass,
     InsertTableOpsPass,
+    MatchArgDtypePass,
     MatchArgRanksPass,
-    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceInfValues,
@@ -116,7 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
+        self.add_pass(MatchArgDtypePass())
         if self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -193,8 +193,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
-
+        self.add_pass(MatchArgDtypePass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
backends/arm/_passes/{match_where_self_arg_dtype_pass.py → match_arg_dtype_pass.py} (renamed)
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -26,7 +26,7 @@ def get_largest_dtype(dtype_1, dtype_2):
     return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


-class MatchWhereSelfDtypePass(ExportPass):
+class MatchArgDtypePass(ExportPass):
     """Pass to match data types of non-condition input tensors.

     Edge dialect allows different data types for non-condition tensors, while TOSA
@@ -38,14 +38,18 @@ class MatchWhereSelfDtypePass(ExportPass):

     """

+    targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self}
+
     def call(self, graph_module: torch.fx.GraphModule):
         modified_graph = False
         graph = graph_module.graph
-        node_list = graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.where.self
-        )
-        for node in node_list:
-            cond, input_, other_ = node.args
+
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            input_ = get_node_arg(node.args, 0)
+            other_ = get_node_arg(node.args, 1)

             input_dtype = input_.meta["val"].dtype
             other_dtype = other_.meta["val"].dtype
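For orientation, here is a minimal, self-contained sketch of the promotion rule this pass applies. The rank values below are illustrative assumptions; the real DTYPE_RANK table is defined earlier in the file, outside the visible hunks.

import torch

# Illustrative rank table (assumed values; the real one is not shown in this diff).
DTYPE_RANK = {
    torch.bool: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
}


def get_largest_dtype(dtype_1, dtype_2):
    # Same shape as the helper in the hunk above: keep the higher-ranked dtype.
    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


# TOSA requires matching input dtypes, so e.g. sub(int32, float32) has its
# int32 argument cast up to the larger dtype before lowering.
assert get_largest_dtype(torch.int32, torch.float32) == torch.float32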
9 changes: 2 additions & 7 deletions backends/arm/test/ops/test_scalars.py
@@ -242,21 +242,16 @@ def test_add_scalar_u85_BI():


 # SUB MI ------------------------------------------------------
-mi_sub_xfails = {
-    "int_r1_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    "int_r4_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    **xfails,
-}


-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_scalar(test_data):
     """Tests regular sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](Sub(), test_data, aten_op=Sub.aten_op)
     pipeline.run()


-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_inplace(test_data):
     """Tests inplace sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](SubInplace(), test_data, aten_op=[])
8 changes: 2 additions & 6 deletions backends/arm/test/ops/test_sub.py
@@ -42,6 +42,8 @@
         torch.randn(1, 4, 4, 1),
         torch.randn(1, 1, 4, 4),
     ),
+    "rand_3d_rand_Scalar": lambda: (torch.rand(1, 6, 2), torch.rand(1)),
+    "rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
 }
 fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}

@@ -93,7 +95,6 @@ def test_sub_tensor_tosa_BI(test_data):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -106,7 +107,6 @@ def test_sub_tensor_tosa_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -121,7 +121,6 @@ def test_sub_tensor_u55_BI(test_data):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -136,7 +135,6 @@ def test_sub_tensor_u55_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -151,7 +149,6 @@ def test_sub_tensor_u85_BI_2(test_data):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -166,5 +163,4 @@ def test_sub_tensor_u85_BI(test_data: Tuple[torch.Tensor, torch.Tensor]):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
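As a quick eager-mode illustration of what the two new test inputs exercise (names taken from the dict added above): PyTorch's type promotion already handles a float tensor minus an int scalar, and the removed xfails suggest the extended pass now lets the lowered TOSA graph match that behavior.

import torch

# The two new test cases from test_sub.py above.
rand_3d_rand_Scalar = (torch.rand(1, 6, 2), torch.rand(1))  # tensor - rank-1 tensor
rand_3d_Scalar = (torch.rand(1, 6, 2), 1)                   # tensor - int scalar

print(torch.sub(*rand_3d_rand_Scalar).shape)  # torch.Size([1, 6, 2]), broadcast over last dim
print(torch.sub(*rand_3d_Scalar).dtype)       # torch.float32: the int scalar is promoted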