Skip to content

Commit 75ec61f

Browse files
committed
address copilot feedback
1 parent 2080f13 commit 75ec61f

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

iris/_distributed_helpers.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -248,12 +248,11 @@ def extract_group_info(group, rank, num_ranks):
248248

249249
group_ranks = dist.get_process_group_ranks(group)
250250
world_size = len(group_ranks)
251-
rank_global = dist.get_rank()
251+
rank_global = rank
252252

253253
if rank_global not in group_ranks:
254254
raise RuntimeError(
255-
f"Current rank {rank_global} is not part of the specified process group. "
256-
f"Group contains ranks: {group_ranks}"
255+
f"Rank {rank_global} is not part of the specified process group. Group contains ranks: {group_ranks}"
257256
)
258257

259258
rank_in_group = group_ranks.index(rank_global)
@@ -315,16 +314,17 @@ def _device_barrier_kernel(
315314
MAX_SPINS: tl.constexpr = 1_000_000_000,
316315
):
317316
"""
318-
Stateless device-side barrier using atomic operations on the symmetric heap.
317+
Device-side barrier using atomic operations on the symmetric heap.
318+
CUDA graph capturable.
319+
320+
Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch
321+
counter. Each rank's flag on the heap serves as its own epoch counter,
322+
managed entirely by the GPU via atomic_add. A persistent per-group flags
323+
tensor is cached in ``_device_barrier_state``.
319324
320325
Launched with grid=(1,). A single CTA:
321326
1. Atomically increments its own flag (atomic_add, release)
322327
2. Serially polls each remote rank's flag for the same value (acquire)
323-
324-
No CPU-side epoch tracking. Each rank's flag IS the epoch, managed
325-
entirely on the GPU via atomic_add. This makes the barrier safe for
326-
CUDA graph capture: during recording the kernel is just recorded,
327-
during replay all ranks increment together.
328328
"""
329329
# Increment own flag and determine target
330330
own_flag_ptr = flags_ptr + iris_rank
@@ -355,15 +355,17 @@ def _device_barrier_kernel(
355355

356356
def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases):
357357
"""
358-
Stateless device-side barrier using atomic operations on the symmetric heap.
358+
Device-side barrier using atomic operations on the symmetric heap.
359+
CUDA graph capturable.
359360
360361
Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``,
361362
this launches a single-CTA Triton kernel that synchronizes via
362363
device-side atomics, making it safe to use during CUDA graph capture.
363364
364-
No CPU-side epoch tracking is needed. Each rank's flag on the symmetric
365-
heap serves as its own epoch counter, managed entirely by the GPU via
366-
atomic_add.
365+
Stateless w.r.t. host-side epoch tracking: each rank's flag on the
366+
symmetric heap serves as its own epoch counter, managed entirely by
367+
the GPU via atomic_add. A persistent per-group flags tensor is cached
368+
in ``_device_barrier_state``.
367369
368370
Args:
369371
flags: int32 tensor on symmetric heap, one element per rank.

iris/iris.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -996,12 +996,14 @@ def barrier(self, stream=None, group=None):
996996

997997
def device_barrier(self, group=None):
998998
"""
999-
Stateless device-side barrier that is CUDA graph capturable.
999+
Device-side barrier that is CUDA graph capturable.
10001000
10011001
Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``,
10021002
this uses device-side atomic operations on the symmetric heap to synchronize
1003-
ranks. No CPU-side epoch tracking -- each rank's flag on the heap serves
1004-
as its own epoch counter, managed entirely by the GPU via atomic_add.
1003+
ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on
1004+
the heap serves as its own epoch counter, managed entirely by the GPU
1005+
via atomic_add. A persistent per-group flags tensor is cached in
1006+
``_device_barrier_state``.
10051007
10061008
Args:
10071009
group (ProcessGroup, optional): The process group to synchronize.

tests/unittests/test_barriers.py

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -55,22 +55,23 @@ def _write_remote_kernel(
5555
@pytest.mark.parametrize("barrier_type", BARRIER_TYPES)
5656
def test_barrier_basic(barrier_type, n):
5757
shmem = iris.iris(1 << 20)
58-
shmem.barrier()
58+
_call_barrier(shmem, barrier_type)
5959

6060
try:
6161
for _ in range(n):
6262
_call_barrier(shmem, barrier_type)
6363
finally:
64-
shmem.barrier()
64+
_call_barrier(shmem, barrier_type)
6565
del shmem
6666
gc.collect()
6767

6868

6969
@pytest.mark.parametrize("n", [1, 2, 5, 10])
70-
def test_barrier_state_reuse(n):
70+
@pytest.mark.parametrize("barrier_type", BARRIER_TYPES)
71+
def test_barrier_state_reuse(barrier_type, n):
7172
"""Verify device barrier reuses the same flags tensor across calls."""
7273
shmem = iris.iris(1 << 20)
73-
shmem.barrier()
74+
_call_barrier(shmem, barrier_type)
7475

7576
try:
7677
shmem.device_barrier()
@@ -82,7 +83,7 @@ def test_barrier_state_reuse(n):
8283
shmem.device_barrier()
8384
assert shmem._device_barrier_state[None].data_ptr() == flags_ptr
8485
finally:
85-
shmem.barrier()
86+
_call_barrier(shmem, barrier_type)
8687
del shmem
8788
gc.collect()
8889

@@ -161,13 +162,13 @@ def _cross_rank_graph(
161162
buf,
162163
result,
163164
):
164-
stream = torch.cuda.Stream()
165+
capture_stream = torch.cuda.Stream()
165166

166167
if op == "load":
167168
buf.fill_(float(rank))
168169

169170
# Warmup on capture stream.
170-
with torch.cuda.stream(stream):
171+
with torch.cuda.stream(capture_stream):
171172
for _ in range(num_barriers):
172173
shmem.device_barrier()
173174
_read_remote_kernel[(1,)](
@@ -180,11 +181,11 @@ def _cross_rank_graph(
180181
)
181182
for _ in range(num_barriers):
182183
shmem.device_barrier()
183-
stream.synchronize()
184+
capture_stream.synchronize()
184185

185186
# Capture.
186187
graph = torch.cuda.CUDAGraph()
187-
with torch.cuda.graph(graph, stream=stream):
188+
with torch.cuda.graph(graph, stream=capture_stream):
188189
for _ in range(num_barriers):
189190
shmem.device_barrier()
190191
_read_remote_kernel[(1,)](
@@ -201,11 +202,11 @@ def _cross_rank_graph(
201202
# Replay with fresh data.
202203
for i in range(rounds):
203204
val = float(rank + (i + 1) * 10)
204-
buf.fill_(val)
205-
shmem.device_barrier()
206-
207-
graph.replay()
208-
stream.synchronize()
205+
with torch.cuda.stream(capture_stream):
206+
buf.fill_(val)
207+
shmem.device_barrier()
208+
graph.replay()
209+
capture_stream.synchronize()
209210

210211
expected = torch.full(
211212
(N,),
@@ -218,7 +219,7 @@ def _cross_rank_graph(
218219
buf.fill_(0.0)
219220

220221
# Warmup on capture stream.
221-
with torch.cuda.stream(stream):
222+
with torch.cuda.stream(capture_stream):
222223
for _ in range(num_barriers):
223224
shmem.device_barrier()
224225
_write_remote_kernel[(1,)](
@@ -231,11 +232,11 @@ def _cross_rank_graph(
231232
)
232233
for _ in range(num_barriers):
233234
shmem.device_barrier()
234-
stream.synchronize()
235+
capture_stream.synchronize()
235236

236237
# Capture.
237238
graph = torch.cuda.CUDAGraph()
238-
with torch.cuda.graph(graph, stream=stream):
239+
with torch.cuda.graph(graph, stream=capture_stream):
239240
for _ in range(num_barriers):
240241
shmem.device_barrier()
241242
_write_remote_kernel[(1,)](
@@ -251,13 +252,15 @@ def _cross_rank_graph(
251252

252253
# Replay and verify.
253254
for _ in range(rounds):
254-
buf.fill_(0.0)
255-
shmem.device_barrier()
256-
257-
graph.replay()
258-
stream.synchronize()
255+
with torch.cuda.stream(capture_stream):
256+
buf.fill_(0.0)
257+
shmem.device_barrier()
258+
graph.replay()
259+
capture_stream.synchronize()
259260

260-
shmem.device_barrier()
261+
with torch.cuda.stream(capture_stream):
262+
shmem.device_barrier()
263+
capture_stream.synchronize()
261264
expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda")
262265
torch.testing.assert_close(buf, expected, rtol=0, atol=0)
263266

@@ -288,7 +291,7 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
288291
)
289292

290293
shmem = iris.iris(1 << 20)
291-
shmem.barrier()
294+
_call_barrier(shmem, barrier_type)
292295
rank = shmem.get_rank()
293296
num_ranks = shmem.get_num_ranks()
294297
heap_bases = shmem.get_heap_bases()
@@ -332,12 +335,13 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
332335
result,
333336
)
334337
finally:
335-
shmem.barrier()
338+
_call_barrier(shmem, barrier_type)
336339
del shmem
337340
gc.collect()
338341

339342

340-
def test_barrier_timeout_assert():
343+
@pytest.mark.parametrize("barrier_type", BARRIER_TYPES)
344+
def test_barrier_timeout_assert(barrier_type):
341345
"""Verify device_barrier asserts on timeout instead of hanging forever.
342346
343347
Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0
@@ -351,7 +355,7 @@ def test_barrier_timeout_assert():
351355
if num_ranks < 2:
352356
pytest.skip("Need at least 2 ranks")
353357

354-
shmem.barrier()
358+
_call_barrier(shmem, barrier_type)
355359

356360
flags = shmem._device_barrier_state.setdefault(None, shmem.zeros((num_ranks,), dtype=torch.int32))
357361

@@ -370,6 +374,8 @@ def test_barrier_timeout_assert():
370374
with pytest.raises(RuntimeError, match="device-side assert"):
371375
torch.cuda.synchronize()
372376
finally:
373-
shmem.barrier()
377+
# No barrier here: rank 0's GPU is dead after the intentional
378+
# device-side assert. Any GPU sync (NCCL or device_barrier)
379+
# will hang or crash.
374380
del shmem
375381
gc.collect()

0 commit comments

Comments (0)