Skip to content

Commit 0da475d

Browse files
committed
[Group Partitioner] leverage group partitioner for config-based partitioner
ghstack-source-id: b2b7dda ghstack-comment-id: 3115642963 Pull Request resolved: #12845
1 parent 00e3f99 commit 0da475d

File tree

3 files changed

+117
-10
lines changed

3 files changed

+117
-10
lines changed

exir/backend/canonical_partitioners/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ runtime.python_library(
1818
deps = [
1919
"//caffe2:torch",
2020
"//executorch/exir/backend:partitioner",
21+
":group_partitioner_lib",
2122
],
2223
)
2324

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,24 @@
1010
import torch
1111
from executorch.exir.backend.backend_details import ExportedProgram
1212
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13-
generate_partitions_from_list_of_nodes,
13+
generate_grouped_partitions_from_list_of_nodes,
1414
)
1515
from executorch.exir.backend.partitioner import (
1616
DelegationSpec,
1717
Partitioner,
1818
PartitionResult,
1919
)
20+
21+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2022
from torch.fx.passes.infra.partitioner import Partition
2123

2224

25+
def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
    """Return True if *node* is constant data in *ep*.

    Constant data means the node is lifted as a parameter, a buffer, or a
    lifted tensor constant of the exported program.
    """
    constant_checks = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(check(ep, node) for check in constant_checks)
29+
30+
2331
def format_target_name(target_name: str) -> str:
2432
"""
2533
We remove the dialect name space from the target name. We generally
@@ -100,6 +108,35 @@ def get_partition(
100108
pass
101109

102110

111+
class DSJ:
    """Disjoint set union (union-find) with path compression.

    Used to collect connected components among graph nodes: elements are
    registered lazily on first ``find``/``union`` and merged groups can be
    read back with :meth:`gen_groups`.
    """

    def __init__(self):
        # Maps each element to its parent; a root maps to itself.
        self.parent = {}

    def find(self, x):
        """Return the root representative of ``x``, registering it if unseen."""
        if x not in self.parent:
            self.parent[x] = x
        root = self.parent[x]
        if root != x:
            root = self.find(root)
            # Path compression: point x straight at its root.
            self.parent[x] = root
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``."""
        root_x, root_y = self.find(x), self.find(y)
        self.parent[root_x] = root_y

    def contains(self, x):
        """Return True if ``x`` has been registered."""
        return x in self.parent

    def gen_groups(self):
        """Return all current components as a list of lists of elements."""
        components = {}
        for element in self.parent:
            components.setdefault(self.find(element), set()).add(element)
        return [list(members) for members in components.values()]
138+
139+
103140
class ConfigerationBasedPartitioner(Partitioner):
104141
def __init__(
105142
self,
@@ -162,23 +199,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
162199
def get_matched_nodes_from_configs(
    self, ep: ExportedProgram
) -> List[List[torch.fx.Node]]:
    """Collect the nodes matched by the registered partitioner configs.

    Walks every ``call_function`` node in *ep*'s graph; when a config is
    registered for the node's target and its constraints hold, the config's
    proposed partition is gathered. Overlapping partitions are merged with a
    disjoint-set union so each matched node ends up in exactly one group.

    Args:
        ep: exported program whose graph is scanned.

    Returns:
        A list of node groups, one per merged partition.
    """
    # Disjoint set union for merging overlapping partitions.
    dsj = DSJ()

    # Gather supported nodes.
    gm = ep.graph_module
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        target = format_target_name(node.target.__name__)

        if target not in self.target_partitioner_configs:
            continue

        node_config = self.target_partitioner_configs[target]
        if not node_config.check_constraints(node, ep):
            continue

        partition_candidate = node_config.get_partition(node, ep)
        partition = []
        # Fix: iterate with a distinct name — the original reused `node`
        # here, shadowing the outer loop variable.
        for candidate in partition_candidate:
            # The partitioner infra copies constant data across partitions,
            # so it is ok if this partition doesn't have it when another
            # partition already claimed it.
            if is_constant_data(ep, candidate) and dsj.contains(candidate):
                continue
            partition.append(candidate)

        # Union overlaps into a single group.
        if partition:
            dsj.find(partition[0])
            for other in partition[1:]:
                dsj.union(partition[0], other)

    return dsj.gen_groups()
177235

178236
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
179237
matched_nodes = self.get_matched_nodes_from_configs(ep)
180238
# create partitions
181-
partitions = generate_partitions_from_list_of_nodes(
239+
partitions = generate_grouped_partitions_from_list_of_nodes(
182240
ep.graph_module,
183241
matched_nodes,
184242
)

exir/backend/canonical_partitioners/pattern_op_partitioner.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from typing import List, Optional
99

1010
import torch
11+
12+
from executorch.exir.backend.canonical_partitioners.group_partitioner import (
13+
GroupBasedPartitioner,
14+
)
1115
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
1216
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
1317
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5660
return partition_list
5761

5862

63+
def generate_grouped_partitions_from_list_of_nodes(
    graph_module: torch.fx.GraphModule,
    pattern_list: Optional[List[List[torch.fx.Node]]] = None,
    op_support: Optional[OperatorSupportBase] = None,
) -> List[Partition]:
    """Propose grouped partitions for the given node groups and/or op support.

    Nodes in *pattern_list* are temporarily tagged via ``node.meta["match"]``
    so the partitioner's operator-support check recognizes them; every tag is
    removed again before returning.

    Args:
        graph_module: graph to partition.
        pattern_list: groups of nodes that should be partitioned together.
        op_support: additional operator-support predicate, chained with the
            tag-based check when both are given.

    Returns:
        Partitions proposed by ``GroupBasedPartitioner``.
    """
    final_op_support: Optional[OperatorSupportBase] = op_support

    if pattern_list is not None:
        # Tag all the nodes in these patterns so MatchTag accepts them.
        for node_list in pattern_list:
            for node in node_list:
                node.meta["match"] = True

        class MatchTag(OperatorSupportBase):
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return node.meta.get("match", False)

        final_op_support = (
            MatchTag()
            if final_op_support is None
            else any_chain(final_op_support, MatchTag())
        )

    assert (
        final_op_support is not None
    ), "Did not give a pattern or OperatorSupportBase instance to partition with"

    # Run the GroupBasedPartitioner to return the largest possible
    # subgraphs containing the nodes with the tags.
    group_partitioner = GroupBasedPartitioner(
        graph_module,
        final_op_support,
        node_groups=pattern_list,
        allows_single_node_partition=True,
    )
    partition_list = group_partitioner.propose_partitions()

    # Fix: remove the tag from *every* node we tagged, not only the nodes
    # that landed in a partition — otherwise tagged nodes the partitioner
    # rejected would keep stale "match" metadata in the graph.
    if pattern_list is not None:
        for node_list in pattern_list:
            for node in node_list:
                node.meta.pop("match", None)
    for partition in partition_list:
        for node in partition.nodes:
            node.meta.pop("match", None)
    return partition_list
105+
106+
59107
def generate_pattern_op_partitions(
60108
graph_module: torch.fx.GraphModule,
61109
patterns: Optional[List[torch.fx.Graph]] = None,

0 commit comments

Comments
 (0)