Commit caeb9ff

Arm backend: Don't run pass call_operator in submodules (#15736)
When a call_operator pass traces a graph that contains a submodule, it uses call_operator to trace the nodes inside the submodule as well, so the pass is applied to them too. This means that 1) passes operate on submodules both when their containing module runs them and when the submodule itself is handled, and 2) the same does not happen for call passes, so passes do not run in the intended order (all call_operator passes run before call passes). To get around this, we detect submodules in ArmPass's call_submodule and fall back to the default call_operator while tracing them.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Erik Lundell <[email protected]>
1 parent b7beb37 commit caeb9ff
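
To make the failure mode concrete, here is a minimal sketch (not part of this commit; CountAddsPass, CondModule, and the observed counts are illustrative) showing that a call_operator override also fires for nodes inside torch.cond branch submodules while the parent graph is traced, which is the behavior the description above builds on.

    import torch
    from executorch.exir.pass_base import ExportPass


    class CountAddsPass(ExportPass):
        """Hypothetical pass: counts how often call_operator sees aten.add.Tensor."""

        def __init__(self) -> None:
            super().__init__()
            self.num_adds_seen = 0

        def call_operator(self, op, args, kwargs, meta):
            if op == torch.ops.aten.add.Tensor:
                self.num_adds_seen += 1
            return super().call_operator(op, args, kwargs, meta)


    class CondModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_branch(arg):
                return arg + 1  # this add lives inside a cond branch submodule

            def false_branch(arg):
                return arg - 1

            return torch.cond(x.sum() > 0, true_branch, false_branch, [x])


    exported = torch.export.export(CondModule(), (torch.randn(2, 3),))
    counting_pass = CountAddsPass()
    counting_pass(exported.graph_module)
    # call_operator fires for the branch nodes while the parent graph is traced;
    # per the commit message, those nodes may then be visited again when each
    # branch submodule is handled on its own, inflating counts and breaking order.
    print(counting_pass.num_adds_seen)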

File tree

2 files changed: +95 -1 lines changed


backends/arm/_passes/arm_pass.py

Lines changed: 23 additions & 1 deletion
@@ -6,14 +6,20 @@
 
 import traceback
 from abc import abstractmethod
-from typing import List, Optional, Set, Type
+from typing import Any, List, Optional, Set, Type
 
 from executorch.exir.pass_base import ExportPass, NodeMetadata
+from torch.fx import GraphModule
+from torch.fx.passes.infra.pass_base import PassResult
 
 
 class ArmPass(ExportPass):
     """Base class for Arm passes"""
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.submodule_depth = 0
+
     @property
     @abstractmethod
     def _passes_required_after(self) -> Set[Type[ExportPass]]:
@@ -56,3 +62,19 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
         old_stack_trace = new_meta.get("stack_trace", "")
         new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
         return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
+
+    def call_submodule(
+        self, graph_module: GraphModule, inputs: tuple[Any, ...]
+    ) -> PassResult:
+        self.submodule_depth += 1
+        if self.submodule_depth == 1:
+            result = super().call_submodule(graph_module, inputs)
+        else:
+            # When we trace a submodule, we don't want to apply the calling pass.
+            # Temporarily replace call_operator to avoid this.
+            _call_operator_fn = self.call_operator
+            self.call_operator = super().call_operator  # type: ignore
+            result = super().call_submodule(graph_module, inputs)
+            self.call_operator = _call_operator_fn  # type: ignore
+        self.submodule_depth -= 1
+        return result
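
The guard above works by shadowing a bound method on the instance: while a nested submodule is being traced, self.call_operator is pointed at the base-class implementation so the subclass override does not fire, and the override is put back afterwards. A standalone sketch of that technique (plain Python, not repo code; Base, Derived, and run are made-up names):

    class Base:
        def visit(self, item):
            return item  # default behavior: pass items through unchanged


    class Derived(Base):
        def visit(self, item):
            return item * 2  # the override we only want applied at the top level

        def run(self, items, nested=False):
            if not nested:
                return [self.visit(i) for i in items]
            # Shadow the override with the base implementation on this instance,
            # do the nested work, then restore the override.
            original_visit = self.visit
            self.visit = super().visit
            try:
                return [self.visit(i) for i in items]
            finally:
                self.visit = original_visit


    d = Derived()
    assert d.run([1, 2]) == [2, 4]               # top level: override applies
    assert d.run([1, 2], nested=True) == [1, 2]  # nested: base behavior applies

ArmPass.call_submodule uses the same idea, keyed on submodule_depth, so a pass's rewrite only applies while the outermost graph module is being traced.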
New test file: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


class _DepthRecordingPass(ArmPass):
    _passes_required_after = set()

    def __init__(self, initial_graph_module):
        super().__init__()
        self.depths: list[int] = []
        self.initial_submodule = initial_graph_module
        self.submodule = None
        self.num_submodules_called = 0

    def call_operator(self, op, args, kwargs, meta, updated: bool = False):
        """Should only be called from the top-level graph module."""
        self.depths.append(self.submodule_depth)
        assert self.submodule == self.initial_submodule
        return super().call_operator(op, args, kwargs, meta, updated)

    def call_submodule(
        self, graph_module: GraphModule, inputs: tuple[Any, ...]
    ) -> PassResult:
        """Should be called for all three graph_modules: top-level, if, and else."""
        self.submodule = graph_module
        self.num_submodules_called += 1
        return super().call_submodule(graph_module, inputs)


class _CondModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def _true_branch(arg: torch.Tensor) -> torch.Tensor:
            return arg + 1

        def _false_branch(arg: torch.Tensor) -> torch.Tensor:
            return arg - 1

        predicate = x.sum() > 0
        return torch.cond(predicate, _true_branch, _false_branch, [x])


def test_call_operator_runs_once_for_cond_submodules() -> None:
    module = _CondModule()
    example_inputs = (torch.randn(2, 3),)
    exported = torch.export.export(module, example_inputs)
    graph_module = exported.graph_module

    recording_pass = _DepthRecordingPass(graph_module)
    pass_manager = ArmPassManager(TosaSpecification.create_from_string("TOSA-1.00+FP"))
    pass_manager.add_pass(recording_pass)
    pass_manager._transform(graph_module)

    assert recording_pass.num_submodules_called == 3
    assert recording_pass.depths, "call_operator was never invoked"
    assert (
        max(recording_pass.depths) == 1
    ), "call_operator was invoked with larger than one submodule depth."
    assert (
        min(recording_pass.depths) == 1
    ), "call_operator was invoked with zero submodule depth."
