Skip to content

Commit b8fe100

Browse files
authored
Change node group partitioning to be with all nodes, to keep partition ids in top sort order (#12871)
### Summary With the current partitioner implementation, we need to traverse the partitions in reverse top sort order for the merging algorithm, as we have some invariants when doing cycle detection and checking downstream dependencies. The current problem is that we first form the group partitions and give them incrementing ids. Since they aren't guaranteed to be in top sort order, when we assign the remaining partitions to the ungrouped nodes, their ids will start from where the grouped ids left off, breaking our invariant. This PR assigns ids in order of reverse top sort, where if we encounter a group node we create the partition for the whole group and if we just encounter ungrouped nodes we create the partition just for that node. ### Test plan All previous tests for functionality still pass
1 parent 7820023 commit b8fe100

File tree

1 file changed

+48
-96
lines changed

1 file changed

+48
-96
lines changed

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 48 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
123123

124124
return True
125125

126-
def _process_node_groups(
126+
def _process_all_nodes(
127127
self,
128128
new_partition_id,
129129
partitions_by_id,
@@ -133,97 +133,60 @@ def _process_node_groups(
133133
partition_users,
134134
partition_map,
135135
):
136-
"""Process nodes in predefined groups."""
137-
group_to_partition_id = {}
138-
139-
if not self.node_groups:
140-
return group_to_partition_id
141-
142-
processed_nodes = set()
143-
144-
# We have to create the partitions in reverse topological order
145-
# so we find the groups as we traverse backwards in the graph
146-
# this likely needs to be combined with the process_remaining_nodes
147-
# TODO: this currently doesn't work with _process_remaining_nodes so
148-
# if a user provides grouped nodes with operatorsupport, then this will
149-
# faile
136+
"""Process nodes into a partition."""
150137
for node in reversed(self.graph_module.graph.nodes):
151-
if node not in self.node_to_group:
138+
if node in assignment or not self._is_node_supported(node):
152139
continue
153140

154-
if node in processed_nodes:
155-
continue
141+
if node in self.all_nodes_in_groups:
142+
group_idx = self.node_to_group[node]
143+
group = self.node_groups[group_idx]
156144

157-
group_idx = self.node_to_group[node]
158-
group = self.node_groups[group_idx]
159-
160-
# Create a partition for group
161-
partition_id = next(new_partition_id)
162-
partition = Partition(id=partition_id, nodes=set())
163-
partitions_by_id[partition_id] = partition
164-
partitions_order[partition_id] = partition_id
165-
group_to_partition_id[group_idx] = partition_id
166-
167-
# Add all supported nodes from the group to the partition
168-
for node in group:
169-
if self._is_node_supported(node):
170-
partition.add_node(node)
171-
assignment[node] = partition_id
172-
nodes_order[node] = partition_id
173-
174-
# Set partition users
175-
partition_users[partition_id] = {
176-
user
177-
for node in partition.nodes
178-
for user in node.users
179-
if user not in partition.nodes
180-
}
181-
182-
# Update partition map
183-
for node in partition.nodes:
145+
# Create a partition for group
146+
partition_id = next(new_partition_id)
147+
partition = Partition(id=partition_id, nodes=set())
148+
partitions_by_id[partition_id] = partition
149+
partitions_order[partition_id] = partition_id
150+
151+
# Add all supported nodes from the group to the partition
152+
for node in group:
153+
if self._is_node_supported(node):
154+
partition.add_node(node)
155+
assignment[node] = partition_id
156+
nodes_order[node] = partition_id
157+
158+
# Set partition users
159+
partition_users[partition_id] = {
160+
user
161+
for node in partition.nodes
162+
for user in node.users
163+
if user not in partition.nodes
164+
}
165+
166+
# Update partition map
167+
for node in partition.nodes:
168+
for user in node.users:
169+
target_id = assignment.get(user, None)
170+
if target_id is not None and target_id != partition_id:
171+
partition_map[partition_id].add(target_id)
172+
partition_map[partition_id].update(partition_map[target_id])
173+
else:
174+
partition_id = next(new_partition_id)
175+
nodes_order[node] = partition_id
176+
partitions_order[partition_id] = partition_id
177+
partitions_by_id[partition_id] = Partition(
178+
id=partition_id, nodes=[node]
179+
)
180+
assignment[node] = partition_id
181+
partition_users[partition_id] = set(node.users)
182+
183+
# Update partition map
184184
for user in node.users:
185185
target_id = assignment.get(user)
186-
if target_id is not None and target_id != partition_id:
186+
if target_id is not None:
187187
partition_map[partition_id].add(target_id)
188188
partition_map[partition_id].update(partition_map[target_id])
189189

190-
# all the nodes in the group have now been processed
191-
# so skip if we encoutner them again in our rev topo
192-
# iteration
193-
for node in group:
194-
processed_nodes.add(node)
195-
196-
return group_to_partition_id
197-
198-
def _process_remaining_nodes(
199-
self,
200-
new_partition_id,
201-
partitions_by_id,
202-
assignment,
203-
nodes_order,
204-
partitions_order,
205-
partition_users,
206-
partition_map,
207-
):
208-
"""Process nodes not in any predefined group."""
209-
for node in reversed(self.graph_module.graph.nodes):
210-
if node in assignment or not self._is_node_supported(node):
211-
continue
212-
213-
partition_id = next(new_partition_id)
214-
nodes_order[node] = partition_id
215-
partitions_order[partition_id] = partition_id
216-
partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
217-
assignment[node] = partition_id
218-
partition_users[partition_id] = set(node.users)
219-
220-
# Update partition map
221-
for user in node.users:
222-
target_id = assignment.get(user)
223-
if target_id is not None:
224-
partition_map[partition_id].add(target_id)
225-
partition_map[partition_id].update(partition_map[target_id])
226-
227190
def _merge_partitions(
228191
self,
229192
partitions_by_id,
@@ -378,19 +341,8 @@ def propose_partitions(self) -> list[Partition]:
378341
partition_users = {} # Maps partition IDs to partition users
379342
new_partition_id = itertools.count()
380343

381-
# Process nodes in predefined groups
382-
self._process_node_groups(
383-
new_partition_id,
384-
partitions_by_id,
385-
assignment,
386-
nodes_order,
387-
partitions_order,
388-
partition_users,
389-
partition_map,
390-
)
391-
392-
# Process remaining nodes
393-
self._process_remaining_nodes(
344+
# Process all nodes into partitions
345+
self._process_all_nodes(
394346
new_partition_id,
395347
partitions_by_id,
396348
assignment,

0 commit comments

Comments
 (0)