From b0027b79f1e6d6d7e31e4c541dfc9d3136c72814 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 8 Oct 2025 14:05:32 +0200 Subject: [PATCH] Arm backend: Fuse duplicate user ops GEN AI USED - BLACKDUCK APPROVED Adds a pass which checks if a node has multiple users performing equivalent operations on its output. If that is the case, it fuses these ops into one. Change-Id: I10f698429af8e2e7f1130c391d7e9e94616fa5e4 Signed-off-by: Adrian Lundell --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 4 + .../arm/_passes/fuse_duplicate_users_pass.py | 165 ++++++++++++++++++ backends/arm/test/ops/test_matmul.py | 32 +++- .../passes/test_fuse_duplicate_users_pass.py | 65 +++++++ 5 files changed, 263 insertions(+), 4 deletions(-) create mode 100644 backends/arm/_passes/fuse_duplicate_users_pass.py create mode 100644 backends/arm/test/passes/test_fuse_duplicate_users_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index b1337c38a58..155765ddcb5 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -77,6 +77,7 @@ ) from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa +from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa from .insert_int32_casts_after_int64_placeholders import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d0d3aae148f..d0c81e1938c 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -79,6 +79,7 @@ FoldAndAnnotateQParamsPass, FuseBatchnorm2DPass, FuseConstantArgsPass, + FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, 
InsertInt32CastsAfterInt64PlaceholdersPass, @@ -180,6 +181,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ConvertELUParamsPass()) self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) + self.add_pass(FuseDuplicateUsersPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) if self.tosa_spec.is_U55_subset: @@ -215,6 +217,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteMatmulPass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) + self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) @@ -225,6 +228,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) + self.add_pass(FuseDuplicateUsersPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) self.add_pass(DecomposeMaskedFill()) diff --git a/backends/arm/_passes/fuse_duplicate_users_pass.py b/backends/arm/_passes/fuse_duplicate_users_pass.py new file mode 100644 index 00000000000..217d93373f8 --- /dev/null +++ b/backends/arm/_passes/fuse_duplicate_users_pass.py @@ -0,0 +1,165 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import deque +from typing import Any, Deque, Dict, Hashable, List, Set, Tuple, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload +from torch.fx import GraphModule, Node +from torch.fx.node import Argument, map_arg + + +class FuseDuplicateUsersPass(ArmPass): + """Fuse identical users of a producer node into a single operation. + + Example: + + y = producer(x) + z0 = torch.add(y, bias) + z1 = torch.add(y, bias) + + becomes a single ``torch.add`` that feeds both consumers. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + producers: Deque[Node] = deque(node for node in graph.nodes) + + while producers: + producer = producers.popleft() + + if producer.graph is None: + # Node was deleted by a previous rewrite while still queued. + continue + + # Only meaningful if a value is consumed by multiple users. + user_nodes = list(producer.users) + if len(user_nodes) < 2: + continue + + candidate_groups = self._get_candidate_groups(user_nodes) + + signature_to_user: Dict[Tuple[Hashable, ...], Node] = {} + for group in candidate_groups: + for user in group: + signature = self._build_user_signature(user) + if signature is None: + continue + + representative = signature_to_user.get(signature) + if representative is None: + # Check if we already encountered identical node that we can fuse with. + signature_to_user[signature] = user + continue + + if user is representative: + # The queue can enqueue the surviving node again after rewrites. 
+ continue + + user.replace_all_uses_with(representative) + graph.erase_node(user) + modified = True + + # Revisit the current producer and the surviving user so that + # newly formed duplicate chains can be fused in later + # iterations. + producers.append(producer) + producers.append(representative) + + if modified: + graph_module.recompile() + graph_module.graph.lint() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + def _get_candidate_groups(self, user_nodes): + users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {} + for user in user_nodes: + if user.graph is None: + # User might already have been removed by a prior rewrite. + continue + + if user.op != "call_function": + continue + + target_key = self._get_target_key(user.target) + target_signature = (user.op, target_key) + users_by_target.setdefault(target_signature, []).append(user) + + candidate_groups = [ + group for group in users_by_target.values() if len(group) > 1 + ] + + return candidate_groups + + def _build_user_signature(self, node: Node) -> Tuple[Hashable, ...] 
| None: + try: + normalized_args = self._to_hashable( + map_arg(node.args, self._map_leaf_to_key) + ) + normalized_kwargs = self._to_hashable( + {k: map_arg(v, self._map_leaf_to_key) for k, v in node.kwargs.items()} + ) + except TypeError: + return None + + target_key = self._get_target_key(node.target) + + return (node.op, target_key, normalized_args, normalized_kwargs) + + def _map_leaf_to_key(self, node: Node) -> Argument: + return node.name + + def _to_hashable(self, value: Any) -> Hashable: + """Convert arbitrarily nested structures into hashable tuples.""" + + if isinstance(value, (list, tuple)): + return tuple(self._to_hashable(v) for v in value) + if isinstance(value, dict): + normalized_items = [(k, self._to_hashable(v)) for k, v in value.items()] + return tuple(sorted(normalized_items, key=lambda item: repr(item[0]))) + if isinstance(value, set): + hashable_values: List[Hashable] = [self._to_hashable(v) for v in value] + return tuple(sorted(hashable_values, key=repr)) + if isinstance(value, slice): + return ( + "slice", + self._to_hashable(value.start), + self._to_hashable(value.stop), + self._to_hashable(value.step), + ) + if isinstance(value, range): + return ("range", value.start, value.stop, value.step) + if isinstance(value, torch.Size): + return ("size", tuple(value)) + if isinstance(value, torch.dtype): + return ("dtype", str(value)) + if isinstance(value, torch.device): + return ("device", str(value)) + if isinstance(value, torch.memory_format): + return ("memory_format", str(value)) + if isinstance(value, torch.Tensor): + return ( + "tensor", + str(value.dtype), + tuple(value.size()), + value.device.type, + value.requires_grad, + ) + return value + + def _get_target_key(self, target: Any) -> Hashable: + if isinstance(target, (EdgeOpOverload, OpOverload)): + return str(target) + return target diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py index f564672e98f..0baf609ce45 100644 --- 
a/backends/arm/test/ops/test_matmul.py +++ b/backends/arm/test/ops/test_matmul.py @@ -134,7 +134,13 @@ def test_matmul_u55_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.parametrize( + "test_data", + MatMulSingleInput.test_data_generators, + xfails={ + "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone300 def test_matmul_single_input_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( @@ -147,7 +153,13 @@ def test_matmul_single_input_u55_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.parametrize( + "test_data", + MatMulCombo.test_data_generators, + xfails={ + "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone300 def test_matmul_combo_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( @@ -173,7 +185,13 @@ def test_matmul_u85_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.parametrize( + "test_data", + MatMulSingleInput.test_data_generators, + xfails={ + "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone320 def test_matmul_single_input_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( @@ -186,7 +204,13 @@ def test_matmul_single_input_u85_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.parametrize( + "test_data", + MatMulCombo.test_data_generators, + xfails={ + "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone320 def test_matmul_combo_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( diff --git 
a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py new file mode 100644 index 00000000000..a7e80794015 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -0,0 +1,65 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import FuseDuplicateUsersPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] # Input x + + +class FuseaAvgPool(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3, + } + ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} + + def __init__(self): + super().__init__() + self.avg = torch.nn.AvgPool2d(1) + + def forward(self, x): + return self.avg(x) + self.avg(x) + self.avg(x) + + +class FuseAvgPoolChain(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6, + } + ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2} + + def __init__(self): + super().__init__() + self.avg = torch.nn.AvgPool2d(1) + + def forward(self, x): + first = self.avg(self.avg(x)) + second = self.avg(self.avg(x)) + third = self.avg(self.avg(x)) + return first + second + third + + +modules = { + "fuse_avg_pool": FuseaAvgPool(), + "fuse_avg_pool_chain": FuseAvgPoolChain(), +} + + +@common.parametrize("module", modules) +def test_fuse_duplicate_ops_FP(module: torch.nn.Module): + pipeline = PassPipeline[input_t]( + module=module, + test_data=(torch.ones(1, 1, 1, 1),), + quantize=False, + ops_before_pass=module.ops_before_pass, + ops_after_pass=module.ops_after_pass, + pass_list=[ + 
FuseDuplicateUsersPass, + ], + ) + pipeline.run()