Skip to content

Commit ed5b952

Browse files
committed
Recurse partitioner through branches
1 parent 73591f1 commit ed5b952

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ConfigerationBasedPartitioner,
2222
)
2323
from executorch.exir.backend.partitioner import DelegationSpec
24+
from torch.fx import GraphModule
2425
from torch.fx.passes.infra.partitioner import Partition
2526

2627
logging.basicConfig(level=logging.WARNING)
@@ -65,25 +66,25 @@ def __init__(
6566
self.per_op_mode = per_op_mode
6667
super().__init__(delegation_spec, initialized_configs)
6768

68-
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
69+
def generate_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
6970
"""
7071
generate_partitions is different if partitioner is set to per_op_mode
7172
for per_op_mode we only need to generate unmerged partitions instead
7273
of using the default generate_partitions method.
7374
"""
7475
if self.per_op_mode:
75-
return self.generate_per_op_partitions(ep)
76+
return self.generate_per_op_partitions(ep, gm)
7677
else:
77-
return super().generate_partitions(ep)
78+
return super().generate_partitions(ep, gm)
7879

79-
def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
80+
def generate_per_op_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
8081
"""
8182
Uses configs to generate per_op_partitions. That is no partitions are
8283
merged together. All partitions (node + deps) returned by PartitionerConfigs
8384
are put into their own partition.
8485
"""
8586
partitions = []
86-
matched_nodes = self.get_matched_nodes_from_configs(ep)
87+
matched_nodes = self.get_matched_nodes_from_configs(ep, gm)
8788
partition_id = itertools.count()
8889
nodes_seen = {}
8990
for match in matched_nodes:

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
Partitioner,
1818
PartitionResult,
1919
)
20+
from torch.fx import GraphModule
2021
from torch.fx.passes.infra.partitioner import Partition
22+
from executorch.exir.graph_module import get_control_flow_submodules
2123

2224

2325
def format_target_name(target_name: str) -> str:
@@ -160,11 +162,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
160162
return (do_not_decomp, filter_fn)
161163

162164
def get_matched_nodes_from_configs(
163-
self, ep: ExportedProgram
165+
self, ep: ExportedProgram, gm: Optional[GraphModule] = None
164166
) -> List[List[torch.fx.Node]]:
165167
# gather supported nodes
166168
matched_nodes = []
167-
gm = ep.graph_module
169+
gm = gm or ep.graph_module
168170
for node in gm.graph.nodes:
169171
if node.op == "call_function":
170172
target = format_target_name(node.target.__name__)
@@ -175,17 +177,19 @@ def get_matched_nodes_from_configs(
175177

176178
return matched_nodes
177179

178-
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
179-
matched_nodes = self.get_matched_nodes_from_configs(ep)
180+
def generate_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
181+
gm = gm or ep.graph_module
182+
matched_nodes = self.get_matched_nodes_from_configs(ep, gm)
180183
# create partitions
181184
partitions = generate_partitions_from_list_of_nodes(
182-
ep.graph_module,
185+
gm,
183186
matched_nodes,
184187
)
185188
return partitions
186189

187-
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
188-
partitions = self.generate_partitions(exported_program)
190+
def partition(self, exported_program: ExportedProgram, graph_module: Optional[GraphModule] = None) -> PartitionResult:
191+
graph_module = graph_module or exported_program.graph_module
192+
partitions = self.generate_partitions(exported_program, graph_module)
189193

190194
# tag nodes
191195
partition_tags: Dict[str, DelegationSpec] = {}
@@ -199,6 +203,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
199203
node.meta["delegation_tag"] = delegation_tag
200204
partition_tags[delegation_tag] = self.delegation_spec
201205

206+
for _, submodule, _ in get_control_flow_submodules(graph_module):
207+
# pyre-ignore
208+
self.partition(exported_program, submodule)
209+
202210
return PartitionResult(
203211
tagged_exported_program=exported_program, partition_tags=partition_tags
204212
)

0 commit comments

Comments
 (0)