diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index e4a1526f573..ac7989f1b9b 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -8,6 +8,7 @@
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
 from .arm_pass import ArmPass  # noqa
+from .broadcast_args_pass import BroadcastArgsPass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 5f79757f212..06758e5de14 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -10,6 +10,7 @@
 from executorch.backends.arm._passes import (
     AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
+    BroadcastArgsPass,
     CastInt64BuffersToInt32Pass,
     CastToInt32Pass,
     ComputeConstantOpsAOT,
@@ -104,6 +105,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
+        if self.tosa_spec.is_U55_subset:
+            self.add_pass(BroadcastArgsPass())
         self.add_pass(ComputeConstantOpsAOT(exported_program))
         self.add_pass(RemoveClonePass())
diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py
new file mode 100644
index 00000000000..f125ba13ff4
--- /dev/null
+++ b/backends/arm/_passes/broadcast_args_pass.py
@@ -0,0 +1,63 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import PassResult
+from torch.fx import GraphModule, Node
+
+
+class BroadcastArgsPass(ArmPass):
+    """
+    Pass to manually broadcast arguments by inserting repeats.
+    This is done when more than one arg needs broadcasting.
+    """
+
+    targeted_ops = {
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
+        # mul indirectly targets div as well, since div is decomposed to reciprocal + mul
+        exir_ops.edge.aten.mul.Tensor,
+    }
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            output_shape = get_first_fake_tensor(node).shape
+            nbr_of_broadcasts = 0
+            for arg in node.args:
+                if not isinstance(arg, Node):
+                    continue
+
+                shape = get_first_fake_tensor(arg).shape
+                if shape != output_shape:
+                    nbr_of_broadcasts += 1
+                    if nbr_of_broadcasts > 1:
+                        multiples = [
+                            int(output_shape[d] / shape[d])
+                            for d in range(len(output_shape))
+                        ]
+                        with graph_module.graph.inserting_before(node):
+                            repeat = create_node(
+                                graph_module.graph,
+                                exir_ops.edge.aten.repeat.default,
+                                args=(arg, multiples),
+                                kwargs={},
+                                from_node=node,
+                            )
+                            node.replace_input_with(arg, repeat)
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 67833576886..76d4950be6d 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -60,6 +60,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         10000 * torch.randn(1, 1, 4, 4),
         torch.randn(1, 1, 4, 1),
     ),
+    "4d_randn_1_multiple_broadcasts": lambda: (
+        torch.randn(1, 4, 4, 1),
+        torch.ones(1, 1, 4, 4),
+    ),
 }
diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py
index 087bdb84a63..0e1ca005fa1 100644
--- a/backends/arm/test/ops/test_div.py
+++ b/backends/arm/test/ops/test_div.py
@@ -66,6 +66,11 @@
         torch.rand(5, 10, 25, 20) + 1,
         None,
     ),
+    "op_div_rank4_randn_multiple_broadcasts": lambda: (
+        torch.randn(1, 4, 4, 1),
+        torch.randn(1, 1, 4, 4),
+        None,
+    ),
 }
diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py
index f960f348a87..a4c0dd4a0f8 100644
--- a/backends/arm/test/ops/test_mul.py
+++ b/backends/arm/test/ops/test_mul.py
@@ -51,6 +51,10 @@
         200 * torch.randn(1, 10, 25, 20),
         torch.rand(1, 10, 25, 1),
     ),
+    "op_mul_rank4_randn_multiple_broadcasts": lambda: (
+        torch.randn(1, 4, 4, 1),
+        torch.randn(1, 1, 4, 4),
+    ),
 }
diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py
index f61f3b0583d..e41e589f6a7 100644
--- a/backends/arm/test/ops/test_sub.py
+++ b/backends/arm/test/ops/test_sub.py
@@ -38,6 +38,10 @@
     "rand_3D_4x4x4": lambda: (torch.rand(4, 2, 2), torch.rand(4, 2, 2)),
     "rand_4D_2x2x4x4": lambda: (torch.rand(2, 2, 4, 4), torch.rand(2, 2, 4, 4)),
     "zeros": lambda: (torch.rand(4, 4), torch.zeros(4, 4)),
+    "randn_4D_multiple_broadcasts": lambda: (
+        torch.randn(1, 4, 4, 1),
+        torch.randn(1, 1, 4, 4),
+    ),
 }
 
 fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}
diff --git a/backends/arm/test/passes/test_broadcast_args_pass.py b/backends/arm/test/passes/test_broadcast_args_pass.py
new file mode 100644
index 00000000000..719a0ddd622
--- /dev/null
+++ b/backends/arm/test/passes/test_broadcast_args_pass.py
@@ -0,0 +1,54 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+from typing import Callable, Tuple
+
+import torch
+from executorch.backends.arm._passes import BroadcastArgsPass
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+input_t = Tuple[torch.Tensor, torch.Tensor]  # Inputs x, y
+
+
+class NeedsMultipleBroadcastsModel(torch.nn.Module):
+    test_data = (torch.rand(1, 10), torch.rand(10, 1))
+
+    def __init__(self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return self.op(x, y)
+
+
+modules = {
+    "add": NeedsMultipleBroadcastsModel(operator.add),
+    "sub": NeedsMultipleBroadcastsModel(operator.sub),
+    "mul": NeedsMultipleBroadcastsModel(operator.mul),
+    "div": NeedsMultipleBroadcastsModel(operator.truediv),
+}
+
+
+@common.parametrize("module", modules)
+def test_multiple_broadcasts_model(module: NeedsMultipleBroadcastsModel):
+    test_data = module.test_data
+    ops_not_before_pass = [
+        "executorch_exir_dialects_edge__ops_aten_repeat_default",
+    ]
+    ops_after_pass = {
+        "executorch_exir_dialects_edge__ops_aten_repeat_default": 1,
+    }
+    pipeline = PassPipeline[input_t](
+        module,
+        test_data,
+        quantize=True,
+        ops_not_before_pass=ops_not_before_pass,
+        ops_after_pass=ops_after_pass,
+        pass_list=[BroadcastArgsPass],
+    )
+    pipeline.run()
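
For reference, a minimal standalone sketch (not part of the patch) of the rewrite the pass performs, using the input shapes from the new test cases. The tensor names and the torch.broadcast_shapes call are illustrative assumptions; the pass itself rewrites FX graph nodes with edge-dialect aten.repeat rather than calling Tensor.repeat directly.

import torch

# Shapes from the new "multiple broadcasts" test cases: both inputs would
# need implicit broadcasting to reach the output shape (1, 4, 4, 4).
x = torch.randn(1, 4, 4, 1)
y = torch.randn(1, 1, 4, 4)

output_shape = torch.broadcast_shapes(x.shape, y.shape)  # (1, 4, 4, 4)

# Same multiples computation as in BroadcastArgsPass: per-dimension ratio
# of the output extent to the argument's extent.
multiples = [output_shape[d] // y.shape[d] for d in range(len(output_shape))]
y_repeated = y.repeat(*multiples)  # explicit repeat, shape (1, 4, 4, 4)

# Only x is still broadcast implicitly, and the result is unchanged.
assert torch.equal(x + y, x + y_repeated)

Note the design choice this mirrors: the pass leaves the first broadcasting argument alone and only materializes repeats for subsequent ones, so each targeted op is left with at most one implicitly broadcast input.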