Skip to content

Commit 56e4457

Browse files
committed
[Group Partitioner] Optimize Speed
ghstack-source-id: 679bfdc ghstack-comment-id: 3115642422 Pull Request resolved: #12844
1 parent 21c8e67 commit 56e4457

File tree

1 file changed

+45
-16
lines changed

1 file changed

+45
-16
lines changed

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
)
8787
self.node_to_group = collections.defaultdict(int)
8888
self.all_nodes_in_groups = set()
89-
if node_groups:
89+
if self.node_groups:
9090
for i, group in enumerate(self.node_groups):
9191
for node in group:
9292
# Node is in multiple groups - not allowed
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
101101
p2_nodes = set(partitions_by_id[p2].nodes.keys())
102102
combined_nodes = p1_nodes.union(p2_nodes)
103103

104-
for node in combined_nodes:
105-
# Get all downstream nodes that are not in the combined partition
106-
external_downstreams = {
107-
n
108-
for n in self.dependency_viewer.downstreams_of(node)
109-
if n not in combined_nodes
110-
}
104+
user_nodes = []
105+
# topologically, p2_nodes comes before p1_nodes, so we only
106+
# need to check the downstream nodes of p2.
107+
# Additionally, we don't need to check all the downstream nodes
108+
# of p2, we only need to check the nodes directly outside of p2.
109+
# example:
110+
# partition[a --> b --> c] --> d --> e --> f
111+
# we don't need to check [d, e, f] we only need to check [d] because
112+
# the downstream users of [d] will include [e, f]
113+
for node in p2_nodes:
114+
for user in node.users:
115+
if user not in combined_nodes:
116+
user_nodes.append(user)
111117

118+
for external_node in user_nodes:
112119
# Check if any external downstream nodes have downstream nodes in the combined partition
113-
for external_node in external_downstreams:
114-
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
115-
if any(n in combined_nodes for n in downstream_nodes):
116-
return False
120+
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
121+
if any(n in combined_nodes for n in downstream_nodes):
122+
return False
117123

118124
return True
119125

@@ -133,13 +139,31 @@ def _process_node_groups(
133139
if not self.node_groups:
134140
return group_to_partition_id
135141

136-
for i, group in enumerate(self.node_groups):
137-
# Create a partition for each group
142+
143+
processed_nodes = set()
144+
145+
# We have to create the partitions in reverse topological order
146+
# so we find the groups as we traverse backwards in the graph
147+
# this likely needs to be combined with the process_remaining_nodes
148+
# TODO: this currently doesn't work with _process_remaining_nodes so
149+
# if a user provides grouped nodes with operatorsupport, then this will
150+
# fail
151+
for node in reversed(self.graph_module.graph.nodes):
152+
if node not in self.node_to_group:
153+
continue
154+
155+
if node in processed_nodes:
156+
continue
157+
158+
group_idx = self.node_to_group[node]
159+
group = self.node_groups[group_idx]
160+
161+
# Create a partition for group
138162
partition_id = next(new_partition_id)
139163
partition = Partition(id=partition_id, nodes=set())
140164
partitions_by_id[partition_id] = partition
141165
partitions_order[partition_id] = partition_id
142-
group_to_partition_id[i] = partition_id
166+
group_to_partition_id[group_idx] = partition_id
143167

144168
# Add all supported nodes from the group to the partition
145169
for node in group:
@@ -164,6 +188,12 @@ def _process_node_groups(
164188
partition_map[partition_id].add(target_id)
165189
partition_map[partition_id].update(partition_map[target_id])
166190

191+
# all the nodes in the group have now been processed
192+
# so skip if we encounter them again in our rev topo
193+
# iteration
194+
for node in group:
195+
processed_nodes.add(node)
196+
167197
return group_to_partition_id
168198

169199
def _process_remaining_nodes(
@@ -209,7 +239,6 @@ def _merge_partitions(
209239

210240
# Set to track removed partitions from initial static list so we can skip them
211241
already_merged = set()
212-
213242
# Try to merge each pair of partitions
214243
for i, p1 in enumerate(partition_ids):
215244
# Skip if this partition has been already merged

0 commit comments

Comments
 (0)