Skip to content

Commit 6ee2bbe

Browse files
committed
Recurse partitioner through branches
1 parent 73591f1 commit 6ee2bbe

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
Partitioner,
1818
PartitionResult,
1919
)
20+
from torch.fx import GraphModule
2021
from torch.fx.passes.infra.partitioner import Partition
22+
from executorch.exir.graph_module import get_control_flow_submodules
2123

2224

2325
def format_target_name(target_name: str) -> str:
@@ -160,11 +162,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
160162
return (do_not_decomp, filter_fn)
161163

162164
def get_matched_nodes_from_configs(
163-
self, ep: ExportedProgram
165+
self, ep: ExportedProgram, gm: Optional[GraphModule] = None
164166
) -> List[List[torch.fx.Node]]:
165167
# gather supported nodes
166168
matched_nodes = []
167-
gm = ep.graph_module
169+
gm = gm or ep.graph_module
168170
for node in gm.graph.nodes:
169171
if node.op == "call_function":
170172
target = format_target_name(node.target.__name__)
@@ -175,17 +177,19 @@ def get_matched_nodes_from_configs(
175177

176178
return matched_nodes
177179

178-
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
179-
matched_nodes = self.get_matched_nodes_from_configs(ep)
180+
def generate_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
181+
gm = gm or ep.graph_module
182+
matched_nodes = self.get_matched_nodes_from_configs(ep, gm)
180183
# create partitions
181184
partitions = generate_partitions_from_list_of_nodes(
182-
ep.graph_module,
185+
gm,
183186
matched_nodes,
184187
)
185188
return partitions
186189

187-
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
188-
partitions = self.generate_partitions(exported_program)
190+
def partition(self, exported_program: ExportedProgram, graph_module: Optional[GraphModule] = None) -> PartitionResult:
191+
graph_module = graph_module or exported_program.graph_module
192+
partitions = self.generate_partitions(exported_program, graph_module)
189193

190194
# tag nodes
191195
partition_tags: Dict[str, DelegationSpec] = {}
@@ -199,6 +203,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
199203
node.meta["delegation_tag"] = delegation_tag
200204
partition_tags[delegation_tag] = self.delegation_spec
201205

206+
for _, submodule, _ in get_control_flow_submodules(graph_module):
207+
# pyre-ignore
208+
self.partition(exported_program, submodule)
209+
202210
return PartitionResult(
203211
tagged_exported_program=exported_program, partition_tags=partition_tags
204212
)

0 commit comments

Comments
 (0)