Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .arm_pass import ArmPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
from .cast_to_int32_pass import CastToInt32Pass # noqa
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from executorch.backends.arm._passes import (
AnnotateChannelsLastDimOrder,
AnnotateDecomposedMatmulPass,
BroadcastArgsPass,
CastInt64BuffersToInt32Pass,
CastToInt32Pass,
ComputeConstantOpsAOT,
Expand Down Expand Up @@ -104,6 +105,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
if self.tosa_spec.is_U55_subset:
self.add_pass(BroadcastArgsPass())
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(RemoveClonePass())
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/_passes/broadcast_args_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import PassResult
from torch.fx import GraphModule, Node


class BroadcastArgsPass(ArmPass):
"""
Pass to manually broadcast arguments by inserting repeats.
This is done when more than one arg needs broadcasting.
"""

targeted_ops = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
# mul is indirectly targeting div as div is decompsed to reciprocal + mul
exir_ops.edge.aten.mul.Tensor,
}

def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in self.targeted_ops:
continue

output_shape = get_first_fake_tensor(node).shape
nbr_of_broacasts = 0
for arg in node.args:
if not isinstance(arg, Node):
continue

shape = get_first_fake_tensor(arg).shape
if shape != output_shape:
nbr_of_broacasts += 1
if nbr_of_broacasts > 1:
multiples = [
int(output_shape[d] / shape[d])
for d in range(len(output_shape))
]
with graph_module.graph.inserting_before(node):
repeat = create_node(
graph_module.graph,
exir_ops.edge.aten.repeat.default,
args=(arg, multiples),
kwargs={},
from_node=node,
)
node.replace_input_with(arg, repeat)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
4 changes: 4 additions & 0 deletions backends/arm/test/ops/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
10000 * torch.randn(1, 1, 4, 4),
torch.randn(1, 1, 4, 1),
),
"4d_randn_1_mutltiple_broadcasts": lambda: (
torch.randn(1, 4, 4, 1),
torch.ones(1, 1, 4, 4),
),
}


Expand Down
5 changes: 5 additions & 0 deletions backends/arm/test/ops/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
torch.rand(5, 10, 25, 20) + 1,
None,
),
"op_div_rank4_randn_mutltiple_broadcasts": lambda: (
torch.randn(1, 4, 4, 1),
torch.randn(1, 1, 4, 4),
None,
),
}


Expand Down
4 changes: 4 additions & 0 deletions backends/arm/test/ops/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
200 * torch.randn(1, 10, 25, 20),
torch.rand(1, 10, 25, 1),
),
"op_mul_rank4_randn_mutltiple_broadcasts": lambda: (
torch.randn(1, 4, 4, 1),
torch.randn(1, 1, 4, 4),
),
}


Expand Down
4 changes: 4 additions & 0 deletions backends/arm/test/ops/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
"rand_3D_4x4x4": lambda: (torch.rand(4, 2, 2), torch.rand(4, 2, 2)),
"rand_4D_2x2x4x4": lambda: (torch.rand(2, 2, 4, 4), torch.rand(2, 2, 4, 4)),
"zeros": lambda: (torch.rand(4, 4), torch.zeros(4, 4)),
"randn_4D_mutltiple_broadcasts": lambda: (
torch.randn(1, 4, 4, 1),
torch.randn(1, 1, 4, 4),
),
}
fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}

Expand Down
54 changes: 54 additions & 0 deletions backends/arm/test/passes/test_broadcast_args_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Tuple

import torch
from executorch.backends.arm._passes import BroadcastArgsPass

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor] # Input x


class NeedsMultipleBroadcastsModel(torch.nn.Module):
test_data = (torch.rand(1, 10), torch.rand(10, 1))

def __init__(self, op: operator):
self.op = op
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor):
return self.op(x, y)


modules = {
"add": NeedsMultipleBroadcastsModel(operator.add),
"sub": NeedsMultipleBroadcastsModel(operator.sub),
"mul": NeedsMultipleBroadcastsModel(operator.mul),
"div": NeedsMultipleBroadcastsModel(operator.truediv),
}


@common.parametrize("module", modules)
def test_multiple_broacasts_model(module: NeedsMultipleBroadcastsModel):
test_data = module.test_data
ops_not_before_pass = [
"executorch_exir_dialects_edge__ops_aten_repeat_default",
]
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_repeat_default": 1,
}
pipeline = PassPipeline[input_t](
module,
test_data,
quantize=True,
ops_not_before_pass=ops_not_before_pass,
ops_after_pass=ops_after_pass,
pass_list=[BroadcastArgsPass],
)
pipeline.run()
Loading