Skip to content

Commit 7ffa511

Browse files
mansiag05 authored and pytorchmergebot committed
[Distributed] Optimize ND shard overlap detection (pytorch#167073)
* Fix the quadratic cost of `validate_non_overlapping_shards_metadata` when the shard count is large by replacing the O(n²) nested-loop scan in `_find_nd_overlapping_shards` with a sweep-line pass, giving O(n log n) behavior for ND overlap detection when few shards are simultaneously active along the sweep dimension (each shard is still compared against every shard that remains active).
* Add test cases in `test_check_overlapping` covering 2D grid patterns, adjacent shards, and 3D multi-shard overlap scenarios to validate the optimized path.

Fixes pytorch#166941
Pull Request resolved: pytorch#167073
Approved by: https://github.com/Skylion007, https://github.com/wconstab
1 parent 25a64df commit 7ffa511

File tree

2 files changed

+109
-29
lines changed

2 files changed

+109
-29
lines changed

test/distributed/_shard/sharding_spec/test_sharding_spec.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,69 @@ def test_check_overlapping(self):
490490
with self.assertRaisesRegex(ValueError, "overlap"):
491491
validate_non_overlapping_shards_metadata(shards)
492492

493+
shards = [
494+
ShardMetadata(
495+
shard_offsets=[0, 0],
496+
shard_sizes=[5, 5],
497+
placement="cuda:0",
498+
),
499+
ShardMetadata(
500+
shard_offsets=[0, 5],
501+
shard_sizes=[5, 5],
502+
placement="cuda:1",
503+
),
504+
ShardMetadata(
505+
shard_offsets=[5, 0],
506+
shard_sizes=[5, 5],
507+
placement="cuda:2",
508+
),
509+
ShardMetadata(
510+
shard_offsets=[5, 5],
511+
shard_sizes=[5, 5],
512+
placement="cuda:3",
513+
),
514+
]
515+
validate_non_overlapping_shards_metadata(shards)
516+
517+
shards = [
518+
ShardMetadata(
519+
shard_offsets=[0, 0],
520+
shard_sizes=[5, 5],
521+
placement="cuda:0",
522+
),
523+
ShardMetadata(
524+
shard_offsets=[5, 5],
525+
shard_sizes=[5, 5],
526+
placement="cuda:1",
527+
),
528+
]
529+
validate_non_overlapping_shards_metadata(shards)
530+
531+
shards = [
532+
ShardMetadata(
533+
shard_offsets=[0, 0, 0],
534+
shard_sizes=[5, 5, 5],
535+
placement="cuda:0",
536+
),
537+
ShardMetadata(
538+
shard_offsets=[5, 0, 0],
539+
shard_sizes=[5, 5, 5],
540+
placement="cuda:1",
541+
),
542+
ShardMetadata(
543+
shard_offsets=[10, 0, 0],
544+
shard_sizes=[5, 5, 5],
545+
placement="cuda:2",
546+
),
547+
ShardMetadata(
548+
shard_offsets=[10, 3, 0],
549+
shard_sizes=[5, 5, 5],
550+
placement="cuda:3",
551+
),
552+
]
553+
with self.assertRaisesRegex(ValueError, "overlap"):
554+
validate_non_overlapping_shards_metadata(shards)
555+
493556

494557
# Custom ShardingSpec, a simple example to do grid sharding
495558
@dataclass

torch/distributed/_shard/sharding_spec/_internals.py

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# mypy: allow-untyped-defs
22
import math
3+
import sys
4+
from bisect import bisect_right, insort
35
from typing import Optional
46

57
from torch.distributed._shard.metadata import ShardMetadata
@@ -27,31 +29,48 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad
2729
def _find_nd_overlapping_shards(
2830
shards: list[ShardMetadata], sharded_dims: list[int]
2931
) -> Optional[tuple[int, int]]:
30-
# Each rank has len(sharded_dims) tuples. Each tuple represent the
31-
# [begin, end] (inclusive) pair of that dimension.
32-
shard_intervals = [
33-
[
34-
(s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1)
35-
for dim in sharded_dims
36-
]
37-
for s in shards
38-
]
39-
40-
for i in range(len(shards)):
41-
shard_i = shard_intervals[i]
42-
for j in range(i + 1, len(shards)):
43-
shard_j = shard_intervals[j]
44-
# For each dim of each shard, check if one shard resides on the other
45-
# end of second shard with respect to that dim. As an example for a 2D
46-
# shard, we would check if one shard is above or on the left of the
47-
# other shard.
48-
overlap = True
49-
for interval_i, interval_j in zip(shard_i, shard_j):
50-
if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]:
51-
overlap = False
52-
break
53-
if overlap:
54-
return (i, j)
32+
"""Find overlapping shards using sweep-line algorithm."""
33+
if len(shards) <= 1:
34+
return None
35+
36+
dims = len(sharded_dims)
37+
if dims == 0:
38+
return None
39+
40+
sweep_dim_idx = 0
41+
if dims > 1:
42+
max_size = 0
43+
for i, dim in enumerate(sharded_dims):
44+
dim_size = shards[0].shard_offsets[dim] + shards[0].shard_sizes[dim]
45+
if dim_size > max_size:
46+
max_size = dim_size
47+
sweep_dim_idx = i
48+
sweep_dim = sharded_dims[sweep_dim_idx]
49+
50+
sorted_indices = sorted(
51+
range(len(shards)),
52+
key=lambda idx: (
53+
shards[idx].shard_offsets[sweep_dim],
54+
*(shards[idx].shard_offsets[d] for d in sharded_dims if d != sweep_dim),
55+
),
56+
)
57+
active: list[tuple[int, int]] = []
58+
59+
for idx in sorted_indices:
60+
current = shards[idx]
61+
start = current.shard_offsets[sweep_dim]
62+
end = start + current.shard_sizes[sweep_dim]
63+
64+
cutoff = bisect_right(active, (start, sys.maxsize))
65+
if cutoff:
66+
del active[:cutoff]
67+
68+
for _, other_idx in active:
69+
other = shards[other_idx]
70+
71+
if _check_shard_metadata_pair_overlap(current, other):
72+
return (other_idx, idx)
73+
insort(active, (end, idx))
5574
return None
5675

5776

@@ -112,10 +131,8 @@ def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]):
112131
# using a O(nlogn) overlapping interval algorithm.
113132
pair = _find_1d_overlapping_shards(shards, sharded_dims[0])
114133
else:
115-
# Shards are partitioned over more than one dimension. Fall back to
116-
# pair-wise check. Even though O(nlogn) algorithms (line sweep) exist
117-
# for 2D overlap, the implementation is not trivial and may not justify
118-
# the time saving in most cases.
134+
# Shards are partitioned over more than one dimension.
135+
# Use sweep-line algorithm for O(n log n) complexity.
119136
pair = _find_nd_overlapping_shards(shards, sharded_dims)
120137

121138
if pair:

0 commit comments

Comments
 (0)