Skip to content

Commit 2034cdd

Browse files
committed
Update
[ghstack-poisoned]
1 parent d669123 commit 2034cdd

File tree

1 file changed

+53
-20
lines changed

1 file changed

+53
-20
lines changed

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717
Partitioner,
1818
PartitionResult,
1919
)
20+
21+
from sympy.logic.boolalg import disjuncts
22+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2023
from torch.fx.passes.infra.partitioner import Partition
2124

2225

26+
def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
27+
return (
28+
is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
29+
)
30+
31+
2332
def format_target_name(target_name: str) -> str:
2433
"""
2534
We remove the dialect name space from the target name. We generally
@@ -100,6 +109,35 @@ def get_partition(
100109
pass
101110

102111

112+
class DSJ:
113+
"""
114+
Disjoint set union data structure used to find connected components in the graph.
115+
"""
116+
117+
def __init__(self):
118+
self.parent = {}
119+
120+
def find(self, x):
121+
self.parent.setdefault(x, x)
122+
if self.parent[x] != x:
123+
self.parent[x] = self.find(self.parent[x])
124+
return self.parent[x]
125+
126+
def union(self, x, y):
127+
self.parent[self.find(x)] = self.find(y)
128+
129+
def contains(self, x):
130+
return x in self.parent
131+
132+
def gen_groups(self):
133+
groups = {}
134+
for node in self.parent.keys():
135+
root = self.find(node)
136+
groups.setdefault(root, set()).add(node)
137+
138+
return [list(group) for group in groups.values()]
139+
140+
103141
class ConfigerationBasedPartitioner(Partitioner):
104142
def __init__(
105143
self,
@@ -162,17 +200,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
162200
def get_matched_nodes_from_configs(
163201
self, ep: ExportedProgram
164202
) -> List[List[torch.fx.Node]]:
165-
# disjoint set union
166-
parent = {}
167-
168-
def find(x):
169-
parent.setdefault(x, x)
170-
if parent[x] != x:
171-
parent[x] = find(parent[x])
172-
return parent[x]
173-
174-
def union(x, y):
175-
parent[find(x)] = find(y)
203+
# disjoint set union for merging partitions
204+
dsj = DSJ()
176205

177206
# gather supported nodes
178207
gm = ep.graph_module
@@ -188,18 +217,22 @@ def union(x, y):
188217
if not node_config.check_constraints(node, ep):
189218
continue
190219

191-
partition = node_config.get_partition(node, ep)
220+
partition_candidate = node_config.get_partition(node, ep)
221+
partition = []
222+
for node in partition_candidate:
223+
# partitioner infra copies constant data across partitions, so it
224+
# is ok if this partition doesn't have it
225+
if is_constant_data(ep, node) and dsj.contains(node):
226+
continue
227+
partition.append(node)
228+
229+
# Union overlaps into a single group
192230
if len(partition) > 0:
193-
parent[partition[0]] = partition[0]
231+
dsj.find(partition[0])
194232
for i in range(1, len(partition)):
195-
union(partition[0], partition[i])
233+
dsj.union(partition[0], partition[i])
196234

197-
groups = {}
198-
for node in parent.keys():
199-
root = find(node)
200-
groups.setdefault(root, set()).add(node)
201-
202-
return [list(group) for group in groups.values()]
235+
return dsj.gen_groups()
203236

204237
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
205238
matched_nodes = self.get_matched_nodes_from_configs(ep)

0 commit comments

Comments
 (0)