Skip to content

Commit 0134e88

Browse files
authored
optimize group based partitioner to not try to merge partitions we already determined were not supposed to be merged (#12798)
### Summary In the group based partitioner we had a case where we would try to merge partitions we already determined couldn't be merged, which was causing a lot more checks to be done. This pr creates a set that tracks the partitions that we determined cant be merged and ensures that when considering 2 new partitions either id isnt in this set. Find below the number of checks on whether each partition can be merged before and after respectively which has a runtime of O(n^2) where n represents the number of nodes in a partition. <img width="84" height="772" alt="image" src="https://github.com/user-attachments/assets/e1f7b47f-9006-494f-afe1-4da0726362f0" /> <img width="68" height="766" alt="image" src="https://github.com/user-attachments/assets/45d4dbe0-f292-4f80-8e66-53aa68ef3cdb" /> ### Test plan I added cnt print logs to check to see the number of times we entered the check both before and after the change.
1 parent 6afe388 commit 0134e88

File tree

1 file changed

+28
-26
lines changed

1 file changed

+28
-26
lines changed

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -204,35 +204,37 @@ def _merge_partitions(
204204
partitions_order,
205205
):
206206
"""Merge partitions when possible."""
207-
merged = True
208-
while merged:
209-
merged = False
210-
partition_ids = list(partitions_by_id.keys())
207+
# Get current partition IDs
208+
partition_ids = list(partitions_by_id.keys())
211209

212-
for i, p1 in enumerate(partition_ids):
213-
if p1 not in partitions_by_id:
210+
# Set to track removed partitions from initial static list so we can skip them
211+
already_merged = set()
212+
213+
# Try to merge each pair of partitions
214+
for i, p1 in enumerate(partition_ids):
215+
# Skip if this partition has been already merged
216+
if p1 in already_merged:
217+
continue
218+
219+
for p2 in partition_ids[i + 1 :]:
220+
# Skip if this partition has been already merged
221+
if p2 in already_merged:
214222
continue
215223

216-
for p2 in partition_ids[i + 1 :]:
217-
if p2 not in partitions_by_id:
218-
continue
219-
220-
# Try to merge partitions if it doesn't create cycles
221-
if self._can_merge_partitions(p1, p2, partitions_by_id):
222-
self._perform_partition_merge(
223-
p1,
224-
p2,
225-
partitions_by_id,
226-
assignment,
227-
partition_users,
228-
partition_map,
229-
partitions_order,
230-
)
231-
merged = True
232-
break
233-
234-
if merged:
235-
break
224+
# Try to merge partitions if it doesn't create cycles
225+
if self._can_merge_partitions(p1, p2, partitions_by_id):
226+
self._perform_partition_merge(
227+
p1,
228+
p2,
229+
partitions_by_id,
230+
assignment,
231+
partition_users,
232+
partition_map,
233+
partitions_order,
234+
)
235+
236+
# Mark p2 as merged
237+
already_merged.add(p2)
236238

237239
def _perform_partition_merge(
238240
self,

0 commit comments

Comments
 (0)