1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -75,6 +75,7 @@
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_int32_casts_after_int64_placeholders import ( # noqa
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -76,6 +76,7 @@
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
FuseConstantArgsPass,
FuseDuplicateUsersPass,
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
@@ -175,6 +176,7 @@ def _tosa_INT_pipeline(
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(FuseDuplicateUsersPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
if self.tosa_spec.is_U55_subset:
@@ -209,6 +211,7 @@ def _tosa_INT_pipeline(
self.add_pass(RewriteMatmulPass())
self.add_pass(RewriteUpsamplePass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))

self.add_pass(InsertRescaleInt32Pass())
self.add_pass(DecomposeSumPass())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
@@ -222,6 +225,7 @@ def _tosa_FP_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseDuplicateUsersPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeLogitPass())
self.add_pass(DecomposeMaskedFill())
165 changes: 165 additions & 0 deletions backends/arm/_passes/fuse_duplicate_users_pass.py
@@ -0,0 +1,165 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Any, Deque, Dict, Hashable, List, Set, Tuple, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, map_arg


class FuseDuplicateUsersPass(ArmPass):
"""Fuse identical users of a producer node into a single operation.

Example:

y = producer(x)
z0 = torch.add(y, bias)
z1 = torch.add(y, bias)

is rewritten so that a single ``torch.add`` remains and its result is reused
wherever ``z0`` and ``z1`` were consumed.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module: GraphModule) -> PassResult:
graph = graph_module.graph
modified = False

producers: Deque[Node] = deque(node for node in graph.nodes)

while producers:
producer = producers.popleft()

if producer.graph is None:
# Node was deleted by a previous rewrite while still queued.
continue

# Only meaningful if a value is consumed by multiple users.
user_nodes = list(producer.users)
if len(user_nodes) < 2:
continue

candidate_groups = self._get_candidate_groups(user_nodes)

signature_to_user: Dict[Tuple[Hashable, ...], Node] = {}
for group in candidate_groups:
for user in group:
signature = self._build_user_signature(user)
if signature is None:
continue

representative = signature_to_user.get(signature)
if representative is None:
# First time this signature is seen; keep the node as the fusion representative.
signature_to_user[signature] = user
continue

if user is representative:
# The queue can enqueue the surviving node again after rewrites.
continue

user.replace_all_uses_with(representative)
graph.erase_node(user)
modified = True

# Revisit the current producer and the surviving user so that
# newly formed duplicate chains can be fused in later
# iterations.
producers.append(producer)
producers.append(representative)

if modified:
graph_module.recompile()
graph_module.graph.lint()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)

def _get_candidate_groups(self, user_nodes):
users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {}
for user in user_nodes:
if user.graph is None:
# User might already have been removed by a prior rewrite.
continue

if user.op != "call_function":
continue

target_key = self._get_target_key(user.target)
target_signature = (user.op, target_key)
users_by_target.setdefault(target_signature, []).append(user)

candidate_groups = [
group for group in users_by_target.values() if len(group) > 1
]

return candidate_groups

def _build_user_signature(self, node: Node) -> Tuple[Hashable, ...] | None:
try:
normalized_args = self._to_hashable(
map_arg(node.args, self._map_leaf_to_key)
)
normalized_kwargs = self._to_hashable(
{k: map_arg(v, self._map_leaf_to_key) for k, v in node.kwargs.items()}
)
except TypeError:
return None

target_key = self._get_target_key(node.target)

return (node.op, target_key, normalized_args, normalized_kwargs)

def _map_leaf_to_key(self, node: Node) -> Argument:
return node.name

def _to_hashable(self, value: Any) -> Hashable:
"""Convert arbitrarily nested structures into hashable tuples."""

if isinstance(value, (list, tuple)):
return tuple(self._to_hashable(v) for v in value)
if isinstance(value, dict):
normalized_items = [(k, self._to_hashable(v)) for k, v in value.items()]
return tuple(sorted(normalized_items, key=lambda item: repr(item[0])))
if isinstance(value, set):
hashable_values: List[Hashable] = [self._to_hashable(v) for v in value]
return tuple(sorted(hashable_values, key=repr))
if isinstance(value, slice):
return (
"slice",
self._to_hashable(value.start),
self._to_hashable(value.stop),
self._to_hashable(value.step),
)
if isinstance(value, range):
return ("range", value.start, value.stop, value.step)
if isinstance(value, torch.Size):
return ("size", tuple(value))
if isinstance(value, torch.dtype):
return ("dtype", str(value))
if isinstance(value, torch.device):
return ("device", str(value))
if isinstance(value, torch.memory_format):
return ("memory_format", str(value))
if isinstance(value, torch.Tensor):
return (
"tensor",
str(value.dtype),
tuple(value.size()),
value.device.type,
value.requires_grad,
)
return value

def _get_target_key(self, target: Any) -> Hashable:
if isinstance(target, (EdgeOpOverload, OpOverload)):
return str(target)
return target
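As a standalone illustration (not part of this PR), the core rewrite the pass performs can be sketched with plain torch.fx: identical call_function users of the same producer are keyed by (op, target, args, kwargs) and merged into one node. The module and variable names below are hypothetical, and the sketch omits the pass's edge-dialect handling, non-hashable-argument normalization, and re-queueing of survivors for chained fusions.

import torch
from torch.fx import symbolic_trace


class DoubleAdd(torch.nn.Module):
    """Two identical users (torch.add) of one producer (torch.relu)."""

    def forward(self, x):
        y = torch.relu(x)
        z0 = torch.add(y, 1.0)
        z1 = torch.add(y, 1.0)
        return z0 + z1


gm = symbolic_trace(DoubleAdd())

for producer in list(gm.graph.nodes):
    seen = {}
    for user in list(producer.users):
        if user.op != "call_function":
            continue
        key = (user.op, str(user.target), user.args, tuple(sorted(user.kwargs.items())))
        try:
            hash(key)
        except TypeError:
            continue  # skip users whose arguments cannot be keyed
        representative = seen.setdefault(key, user)
        if representative is not user:
            # Identical user found: redirect its uses and drop the duplicate.
            user.replace_all_uses_with(representative)
            gm.graph.erase_node(user)

gm.graph.lint()
gm.recompile()
print(gm.graph)  # only one torch.add node remains, feeding both former uses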
32 changes: 28 additions & 4 deletions backends/arm/test/ops/test_matmul.py
@@ -134,7 +134,13 @@ def test_matmul_u55_INT(test_data: input_t1):
pipeline.run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.parametrize(
"test_data",
MatMulSingleInput.test_data_generators,
xfails={
"rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
},
)
@common.XfailIfNoCorstone300
def test_matmul_single_input_u55_INT(test_data: input_t1):
pipeline = EthosU55PipelineINT[input_t1](
@@ -147,7 +153,13 @@ def test_matmul_single_input_u55_INT(test_data: input_t1):
pipeline.run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.parametrize(
"test_data",
MatMulCombo.test_data_generators,
xfails={
"rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
},
)
@common.XfailIfNoCorstone300
def test_matmul_combo_u55_INT(test_data: input_t1):
pipeline = EthosU55PipelineINT[input_t1](
@@ -173,7 +185,13 @@ def test_matmul_u85_INT(test_data: input_t1):
pipeline.run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.parametrize(
"test_data",
MatMulSingleInput.test_data_generators,
xfails={
"rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
},
)
@common.XfailIfNoCorstone320
def test_matmul_single_input_u85_INT(test_data: input_t1):
pipeline = EthosU85PipelineINT[input_t1](
@@ -186,7 +204,13 @@ def test_matmul_single_input_u85_INT(test_data: input_t1):
pipeline.run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.parametrize(
"test_data",
MatMulCombo.test_data_generators,
xfails={
"rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
},
)
@common.XfailIfNoCorstone320
def test_matmul_combo_u85_INT(test_data: input_t1):
pipeline = EthosU85PipelineINT[input_t1](
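For context on the MLBEDSW-11228 xfails above: these accompany the new pass, presumably because once identical users are fused, the matmul reaches the backend in its single-input form, and the referenced ticket reports an output difference between that form and the two-identical-inputs form on the Ethos-U targets. A rough, standalone illustration of the two graph shapes (hypothetical module names, not the test models used in this file):

import torch


class MatmulSharedOperand(torch.nn.Module):
    """Both operands are the same graph value ("1 input")."""

    def forward(self, x):
        return torch.matmul(x, x)


class MatmulTwoIdenticalOperands(torch.nn.Module):
    """Operands are equal in value but arrive as two distinct inputs."""

    def forward(self, x, y):
        return torch.matmul(x, y)


x = torch.randn(1, 2, 4, 4)
# In eager mode the two forms agree; the xfail tracks a discrepancy between
# their lowered forms on the Ethos-U targets, per MLBEDSW-11228.
assert torch.allclose(MatmulSharedOperand()(x), MatmulTwoIdenticalOperands()(x, x.clone()))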
65 changes: 65 additions & 0 deletions backends/arm/test/passes/test_fuse_duplicate_users_pass.py
@@ -0,0 +1,65 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm._passes import FuseDuplicateUsersPass
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor] # Input x


class FuseAvgPool(torch.nn.Module):
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3,
}
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1}

def __init__(self):
super().__init__()
self.avg = torch.nn.AvgPool2d(1)

def forward(self, x):
return self.avg(x) + self.avg(x) + self.avg(x)


class FuseAvgPoolChain(torch.nn.Module):
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6,
}
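# Two nodes survive: the three identical inner avg_pool users of x fuse into
# one node, which in turn makes the three outer avg_pool calls identical users
# of that node, so they fuse into a second one.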
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2}

def __init__(self):
super().__init__()
self.avg = torch.nn.AvgPool2d(1)

def forward(self, x):
first = self.avg(self.avg(x))
second = self.avg(self.avg(x))
third = self.avg(self.avg(x))
return first + second + third


modules = {
"fuse_avg_pool": FuseaAvgPool(),
"fuse_avg_pool_chain": FuseAvgPoolChain(),
}


@common.parametrize("module", modules)
def test_fuse_duplicate_ops_FP(module: torch.nn.Module):
pipeline = PassPipeline[input_t](
module=module,
test_data=(torch.ones(1, 1, 1, 1),),
quantize=False,
ops_before_pass=module.ops_before_pass,
ops_after_pass=module.ops_after_pass,
pass_list=[
FuseDuplicateUsersPass,
],
)
pipeline.run()