diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1d6c34a0d35..551abc6b9b3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -67,8 +67,8 @@ ) from .insert_rescales_pass import InsertRescalePass # noqa from .insert_table_ops import InsertTableOpsPass # noqa +from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa -from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa from .remove_clone_pass import RemoveClonePass # noqa from .replace_scalar_with_tensor_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index bfae8f1b017..29c3603fced 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -66,8 +66,8 @@ InsertCastForOpsWithInt64InputPass, InsertRescalePass, InsertTableOpsPass, + MatchArgDtypePass, MatchArgRanksPass, - MatchWhereSelfDtypePass, QuantizeOperatorArguments, RemoveClonePass, ReplaceInfValues, @@ -116,7 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchWhereSelfDtypePass()) + self.add_pass(MatchArgDtypePass()) if self.tosa_spec.is_U55_subset: self.add_pass(CastToInt32Pass()) @@ -193,8 +193,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchWhereSelfDtypePass()) - + self.add_pass(MatchArgDtypePass()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] diff --git 
a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py similarity index 90% rename from backends/arm/_passes/match_where_self_arg_dtype_pass.py rename to backends/arm/_passes/match_arg_dtype_pass.py index fdbd4433bab..e7bf3b2d60e 100644 --- a/backends/arm/_passes/match_where_self_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,7 +26,7 @@ def get_largest_dtype(dtype_1, dtype_2): return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2 -class MatchWhereSelfDtypePass(ExportPass): +class MatchArgDtypePass(ExportPass): """Pass to match data types of non-condition input tensors. 
Edge dialect allows different data types for non-condition tensors, while TOSA @@ -38,14 +38,18 @@ class MatchWhereSelfDtypePass(ExportPass): """ + targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self} + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph - node_list = graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.where.self - ) - for node in node_list: - cond, input_, other_ = node.args + + for node in list(graph.nodes): + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + + offset = 1 if node.target == exir_ops.edge.aten.where.self else 0 + input_, other_ = get_node_arg(node.args, offset), get_node_arg(node.args, offset + 1) input_dtype = input_.meta["val"].dtype other_dtype = other_.meta["val"].dtype diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 7a06f7dfc8d..3ede947b218 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -242,21 +242,16 @@ def test_add_scalar_u85_BI(): # SUB MI ------------------------------------------------------ -mi_sub_xfails = { - "int_r1_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8", - "int_r4_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8", - **xfails, -} -@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails) +@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails) def test_sub_tensor_tosa_MI_scalar(test_data): """Tests regular sub with one scalar input.""" pipeline = TosaPipelineMI[input_t1](Sub(), test_data, aten_op=Sub.aten_op) pipeline.run() -@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails) +@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails) def test_sub_tensor_tosa_MI_inplace(test_data): """Tests inplace sub with one scalar input.""" pipeline = TosaPipelineMI[input_t1](SubInplace(), test_data, aten_op=[]) diff
--git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index e41e589f6a7..5957e27d5a9 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -42,6 +42,8 @@ torch.randn(1, 4, 4, 1), torch.randn(1, 1, 4, 4), ), + "rand_3d_rand_Scalar": lambda: (torch.rand(1, 6, 2), torch.rand(1)), + "rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1), } fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"} @@ -93,7 +95,6 @@ def test_sub_tensor_tosa_BI(test_data): aten_op, exir_op, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -106,7 +107,6 @@ def test_sub_tensor_tosa_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]): aten_op, exir_op, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -121,7 +121,6 @@ def test_sub_tensor_u55_BI(test_data): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -136,7 +135,6 @@ def test_sub_tensor_u55_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -151,7 +149,6 @@ def test_sub_tensor_u85_BI_2(test_data): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -166,5 +163,4 @@ def test_sub_tensor_u85_BI(test_data: Tuple[torch.Tensor, torch.Tensor]): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run()