
Commit ab4b04c

[Group Partitioner] leverage group partitioner for config-based partitioner
ghstack-source-id: 5507008
ghstack-comment-id: 3115642963
Pull Request resolved: #12845
1 parent 0abe452 commit ab4b04c

File tree (2 files changed: +117 -10 lines)

  exir/backend/canonical_partitioners/config_partitioner.py
  exir/backend/canonical_partitioners/pattern_op_partitioner.py

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 69 additions & 10 deletions
@@ -10,16 +10,25 @@
 import torch
 from executorch.exir.backend.backend_details import ExportedProgram
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
-    generate_partitions_from_list_of_nodes,
+    generate_grouped_partitions_from_list_of_nodes,
 )
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
     Partitioner,
     PartitionResult,
 )
+
+from sympy.logic.boolalg import disjuncts
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.fx.passes.infra.partitioner import Partition


+def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
+    )
+
+
 def format_target_name(target_name: str) -> str:
     """
     We remove the dialect name space from the target name. We generally
@@ -100,6 +109,35 @@ def get_partition(
         pass


+class DSJ:
+    """
+    Disjoint set union data structure used to find connected components in the graph.
+    """
+
+    def __init__(self):
+        self.parent = {}
+
+    def find(self, x):
+        self.parent.setdefault(x, x)
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, x, y):
+        self.parent[self.find(x)] = self.find(y)
+
+    def contains(self, x):
+        return x in self.parent
+
+    def gen_groups(self):
+        groups = {}
+        for node in self.parent.keys():
+            root = self.find(node)
+            groups.setdefault(root, set()).add(node)
+
+        return [list(group) for group in groups.values()]
+
+
 class ConfigerationBasedPartitioner(Partitioner):
     def __init__(
         self,
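For readers unfamiliar with union-find, a minimal usage sketch of the DSJ class added above; the string keys here are illustrative stand-ins for the torch.fx.Node objects the partitioner actually stores:

# Minimal sketch of the DSJ API (hypothetical keys, not from this commit).
dsj = DSJ()
dsj.union("a", "b")        # "a" and "b" end up in one set
dsj.union("c", "d")        # "c" and "d" in another
dsj.union("b", "c")        # an overlapping union merges both sets into one
print(dsj.contains("a"))   # True
print(dsj.contains("e"))   # False; contains() never inserts new keys
print(dsj.gen_groups())    # [["a", "b", "c", "d"]] (order within a group is unspecified)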
@@ -162,23 +200,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
     def get_matched_nodes_from_configs(
         self, ep: ExportedProgram
     ) -> List[List[torch.fx.Node]]:
+        # disjoint set union for merging partitions
+        dsj = DSJ()
+
         # gather supported nodes
-        matched_nodes = []
         gm = ep.graph_module
         for node in gm.graph.nodes:
-            if node.op == "call_function":
-                target = format_target_name(node.target.__name__)
-                if target in self.target_partitioner_configs:
-                    node_config = self.target_partitioner_configs[target]
-                    if node_config.check_constraints(node, ep):
-                        matched_nodes.append(node_config.get_partition(node, ep))
+            if node.op != "call_function":
+                continue
+            target = format_target_name(node.target.__name__)
+
+            if target not in self.target_partitioner_configs:
+                continue
+
+            node_config = self.target_partitioner_configs[target]
+            if not node_config.check_constraints(node, ep):
+                continue
+
+            partition_candidate = node_config.get_partition(node, ep)
+            partition = []
+            for node in partition_candidate:
+                # partitioner infra copies constant data across partitions, so it
+                # is ok if this partition doesn't have it
+                if is_constant_data(ep, node) and dsj.contains(node):
+                    continue
+                partition.append(node)
+
+            # Union overlaps into a single group
+            if len(partition) > 0:
+                dsj.find(partition[0])
+                for i in range(1, len(partition)):
+                    dsj.union(partition[0], partition[i])

-        return matched_nodes
+        return dsj.gen_groups()

     def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
         matched_nodes = self.get_matched_nodes_from_configs(ep)
         # create partitions
-        partitions = generate_partitions_from_list_of_nodes(
+        partitions = generate_grouped_partitions_from_list_of_nodes(
             ep.graph_module,
             matched_nodes,
         )
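The practical effect of the rewrite above: when two matched nodes return overlapping partition candidates, the DSJ collapses them into one group instead of emitting overlapping lists. A small sketch with hypothetical node names (the real code operates on torch.fx.Node objects and an ExportedProgram):

# Hypothetical candidates, e.g. from two different PartitionerConfig.get_partition calls.
candidates = [
    ["linear", "weight", "bias"],  # candidate from the linear node's config
    ["relu", "linear"],            # candidate from the relu node's config; overlaps on "linear"
]

dsj = DSJ()
for partition in candidates:
    if partition:
        dsj.find(partition[0])            # register the group even if it has one node
        for other in partition[1:]:
            dsj.union(partition[0], other)

print(dsj.gen_groups())  # a single merged group, e.g. [["linear", "weight", "bias", "relu"]]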

exir/backend/canonical_partitioners/pattern_op_partitioner.py

Lines changed: 48 additions & 0 deletions
@@ -8,6 +8,10 @@
 from typing import List, Optional

 import torch
+
+from executorch.exir.backend.canonical_partitioners.group_partitioner import (
+    GroupBasedPartitioner,
+)
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
     return partition_list


+def generate_grouped_partitions_from_list_of_nodes(
+    graph_module: torch.fx.GraphModule,
+    pattern_list: Optional[List[List[torch.fx.Node]]] = None,
+    op_support: Optional[OperatorSupportBase] = None,
+) -> List[Partition]:
+    final_op_support: Optional[OperatorSupportBase] = op_support
+
+    if pattern_list is not None:
+        # Tag all the nodes in these patterns
+        for node_list in pattern_list:
+            for node in node_list:
+                node.meta["match"] = True
+
+        class MatchTag(OperatorSupportBase):
+            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+                return node.meta.get("match", False)
+
+        final_op_support = (
+            MatchTag()
+            if final_op_support is None
+            else any_chain(final_op_support, MatchTag())
+        )
+
+    assert (
+        final_op_support is not None
+    ), "Did not give a pattern or OperatorSupportBase instance to partition with"
+
+    # Run the CapabilityBasedPartitioner to return the largest possible
+    # subgraphs containing the nodes with the tags
+    group_partitioner = GroupBasedPartitioner(
+        graph_module,
+        final_op_support,
+        node_groups=pattern_list,
+        allows_single_node_partition=True,
+    )
+    partition_list = group_partitioner.propose_partitions()
+
+    # Remove the metadata field we added
+    for partition in partition_list:
+        for node in partition.nodes:
+            node.meta.pop("match", False)
+    return partition_list
+
+
 def generate_pattern_op_partitions(
     graph_module: torch.fx.GraphModule,
     patterns: Optional[List[torch.fx.Graph]] = None,
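For orientation, a minimal usage sketch of the new helper. The signature comes from the diff above; the toy module and node selection below are illustrative assumptions, not part of this commit, and assume GroupBasedPartitioner proposes partitions the same way the capability-based path does:

# Sketch: group the call_function nodes of a tiny traced module into one partition.
import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_grouped_partitions_from_list_of_nodes,
)


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


gm = torch.fx.symbolic_trace(Toy())

# Hypothetical grouping: keep the add and relu nodes together.
group = [n for n in gm.graph.nodes if n.op == "call_function"]
partitions = generate_grouped_partitions_from_list_of_nodes(gm, pattern_list=[group])
for p in partitions:
    print(p.id, [n.name for n in p.nodes])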
