Skip to content

Commit 4099117

Browse files
committed
Change node group partitioning to be with all nodes, to keep partition ids in top sort order
1 parent 66e5591 commit 4099117

File tree

1 file changed

+46
-94
lines changed

1 file changed

+46
-94
lines changed

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 46 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
123123

124124
return True
125125

126-
def _process_node_groups(
126+
def _process_all_nodes(
127127
self,
128128
new_partition_id,
129129
partitions_by_id,
@@ -133,97 +133,60 @@ def _process_node_groups(
133133
partition_users,
134134
partition_map,
135135
):
136-
"""Process nodes in predefined groups."""
137-
group_to_partition_id = {}
138-
139-
if not self.node_groups:
140-
return group_to_partition_id
141-
142-
processed_nodes = set()
143-
144-
# We have to create the partitions in reverse topological order
145-
# so we find the groups as we traverse backwards in the graph
146-
# this likely needs to be combined with the process_remaining_nodes
147-
# TODO: this currently doesn't work with _process_remaining_nodes so
148-
# if a user provides grouped nodes with operatorsupport, then this will
149-
# faile
136+
"""Process nodes not in any predefined group."""
150137
for node in reversed(self.graph_module.graph.nodes):
151-
if node not in self.node_to_group:
138+
if node in assignment or not self._is_node_supported(node):
152139
continue
153140

154-
if node in processed_nodes:
155-
continue
141+
if node in self.all_nodes_in_groups:
142+
group_idx = self.node_to_group[node]
143+
group = self.node_groups[group_idx]
156144

157-
group_idx = self.node_to_group[node]
158-
group = self.node_groups[group_idx]
159-
160-
# Create a partition for group
161-
partition_id = next(new_partition_id)
162-
partition = Partition(id=partition_id, nodes=set())
163-
partitions_by_id[partition_id] = partition
164-
partitions_order[partition_id] = partition_id
165-
group_to_partition_id[group_idx] = partition_id
166-
167-
# Add all supported nodes from the group to the partition
168-
for node in group:
169-
if self._is_node_supported(node):
170-
partition.add_node(node)
171-
assignment[node] = partition_id
172-
nodes_order[node] = partition_id
173-
174-
# Set partition users
175-
partition_users[partition_id] = {
176-
user
177-
for node in partition.nodes
178-
for user in node.users
179-
if user not in partition.nodes
180-
}
181-
182-
# Update partition map
183-
for node in partition.nodes:
145+
# Create a partition for group
146+
partition_id = next(new_partition_id)
147+
partition = Partition(id=partition_id, nodes=set())
148+
partitions_by_id[partition_id] = partition
149+
partitions_order[partition_id] = partition_id
150+
151+
# Add all supported nodes from the group to the partition
152+
for node in group:
153+
if self._is_node_supported(node):
154+
partition.add_node(node)
155+
assignment[node] = partition_id
156+
nodes_order[node] = partition_id
157+
158+
# Set partition users
159+
partition_users[partition_id] = {
160+
user
161+
for node in partition.nodes
162+
for user in node.users
163+
if user not in partition.nodes
164+
}
165+
166+
# Update partition map
167+
for node in partition.nodes:
168+
for user in node.users:
169+
target_id = assignment.get(user, None)
170+
if target_id is not None and target_id != partition_id:
171+
partition_map[partition_id].add(target_id)
172+
partition_map[partition_id].update(partition_map[target_id])
173+
else:
174+
partition_id = next(new_partition_id)
175+
nodes_order[node] = partition_id
176+
partitions_order[partition_id] = partition_id
177+
partitions_by_id[partition_id] = Partition(
178+
id=partition_id, nodes=[node]
179+
)
180+
assignment[node] = partition_id
181+
partition_users[partition_id] = set(node.users)
182+
183+
# Update partition map
184184
for user in node.users:
185185
target_id = assignment.get(user)
186-
if target_id is not None and target_id != partition_id:
186+
if target_id is not None:
187187
partition_map[partition_id].add(target_id)
188188
partition_map[partition_id].update(partition_map[target_id])
189189

190-
# all the nodes in the group have now been processed
191-
# so skip if we encoutner them again in our rev topo
192-
# iteration
193-
for node in group:
194-
processed_nodes.add(node)
195-
196-
return group_to_partition_id
197-
198-
def _process_remaining_nodes(
199-
self,
200-
new_partition_id,
201-
partitions_by_id,
202-
assignment,
203-
nodes_order,
204-
partitions_order,
205-
partition_users,
206-
partition_map,
207-
):
208-
"""Process nodes not in any predefined group."""
209-
for node in reversed(self.graph_module.graph.nodes):
210-
if node in assignment or not self._is_node_supported(node):
211-
continue
212-
213-
partition_id = next(new_partition_id)
214-
nodes_order[node] = partition_id
215-
partitions_order[partition_id] = partition_id
216-
partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
217-
assignment[node] = partition_id
218-
partition_users[partition_id] = set(node.users)
219-
220-
# Update partition map
221-
for user in node.users:
222-
target_id = assignment.get(user)
223-
if target_id is not None:
224-
partition_map[partition_id].add(target_id)
225-
partition_map[partition_id].update(partition_map[target_id])
226-
227190
def _merge_partitions(
228191
self,
229192
partitions_by_id,
@@ -378,17 +341,6 @@ def propose_partitions(self) -> list[Partition]:
378341
partition_users = {} # Maps partition IDs to partition users
379342
new_partition_id = itertools.count()
380343

381-
# Process nodes in predefined groups
382-
self._process_node_groups(
383-
new_partition_id,
384-
partitions_by_id,
385-
assignment,
386-
nodes_order,
387-
partitions_order,
388-
partition_users,
389-
partition_map,
390-
)
391-
392344
# Process remaining nodes
393345
self._process_remaining_nodes(
394346
new_partition_id,

0 commit comments

Comments
 (0)