Skip to content

Commit 4d1cc6f

Browse files
committed
address copilot feedback
1 parent 2080f13 commit 4d1cc6f

File tree

3 files changed

+48
-80
lines changed

3 files changed

+48
-80
lines changed

iris/_distributed_helpers.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -248,12 +248,11 @@ def extract_group_info(group, rank, num_ranks):
248248

249249
group_ranks = dist.get_process_group_ranks(group)
250250
world_size = len(group_ranks)
251-
rank_global = dist.get_rank()
251+
rank_global = rank
252252

253253
if rank_global not in group_ranks:
254254
raise RuntimeError(
255-
f"Current rank {rank_global} is not part of the specified process group. "
256-
f"Group contains ranks: {group_ranks}"
255+
f"Rank {rank_global} is not part of the specified process group. Group contains ranks: {group_ranks}"
257256
)
258257

259258
rank_in_group = group_ranks.index(rank_global)
@@ -315,16 +314,17 @@ def _device_barrier_kernel(
315314
MAX_SPINS: tl.constexpr = 1_000_000_000,
316315
):
317316
"""
318-
Stateless device-side barrier using atomic operations on the symmetric heap.
317+
Device-side barrier using atomic operations on the symmetric heap.
318+
CUDA graph capturable.
319+
320+
Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch
321+
counter. Each rank's flag on the heap serves as its own epoch counter,
322+
managed entirely by the GPU via atomic_add. A persistent per-group flags
323+
tensor is cached in ``_device_barrier_state``.
319324
320325
Launched with grid=(1,). A single CTA:
321326
1. Atomically increments its own flag (atomic_add, release)
322327
2. Serially polls each remote rank's flag for the same value (acquire)
323-
324-
No CPU-side epoch tracking. Each rank's flag IS the epoch, managed
325-
entirely on the GPU via atomic_add. This makes the barrier safe for
326-
CUDA graph capture: during recording the kernel is just recorded,
327-
during replay all ranks increment together.
328328
"""
329329
# Increment own flag and determine target
330330
own_flag_ptr = flags_ptr + iris_rank
@@ -355,15 +355,17 @@ def _device_barrier_kernel(
355355

356356
def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases):
357357
"""
358-
Stateless device-side barrier using atomic operations on the symmetric heap.
358+
Device-side barrier using atomic operations on the symmetric heap.
359+
CUDA graph capturable.
359360
360361
Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``,
361362
this launches a single-CTA Triton kernel that synchronizes via
362363
device-side atomics, making it safe to use during CUDA graph capture.
363364
364-
No CPU-side epoch tracking is needed. Each rank's flag on the symmetric
365-
heap serves as its own epoch counter, managed entirely by the GPU via
366-
atomic_add.
365+
Stateless w.r.t. host-side epoch tracking: each rank's flag on the
366+
symmetric heap serves as its own epoch counter, managed entirely by
367+
the GPU via atomic_add. A persistent per-group flags tensor is cached
368+
in ``_device_barrier_state``.
367369
368370
Args:
369371
flags: int32 tensor on symmetric heap, one element per rank.

iris/iris.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -996,12 +996,14 @@ def barrier(self, stream=None, group=None):
996996

997997
def device_barrier(self, group=None):
998998
"""
999-
Stateless device-side barrier that is CUDA graph capturable.
999+
Device-side barrier that is CUDA graph capturable.
10001000
10011001
Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``,
10021002
this uses device-side atomic operations on the symmetric heap to synchronize
1003-
ranks. No CPU-side epoch tracking -- each rank's flag on the heap serves
1004-
as its own epoch counter, managed entirely by the GPU via atomic_add.
1003+
ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on
1004+
the heap serves as its own epoch counter, managed entirely by the GPU
1005+
via atomic_add. A persistent per-group flags tensor is cached in
1006+
``_device_barrier_state``.
10051007
10061008
Args:
10071009
group (ProcessGroup, optional): The process group to synchronize.

tests/unittests/test_barriers.py

Lines changed: 28 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import triton
1010
import triton.language as tl
1111
import iris
12-
from iris._distributed_helpers import _device_barrier_kernel, extract_group_info
1312

1413

1514
BarrierType = Literal["host", "device"]
@@ -55,22 +54,23 @@ def _write_remote_kernel(
5554
@pytest.mark.parametrize("barrier_type", BARRIER_TYPES)
5655
def test_barrier_basic(barrier_type, n):
5756
shmem = iris.iris(1 << 20)
58-
shmem.barrier()
57+
_call_barrier(shmem, barrier_type)
5958

6059
try:
6160
for _ in range(n):
6261
_call_barrier(shmem, barrier_type)
6362
finally:
64-
shmem.barrier()
63+
_call_barrier(shmem, barrier_type)
6564
del shmem
6665
gc.collect()
6766

6867

6968
@pytest.mark.parametrize("n", [1, 2, 5, 10])
70-
def test_barrier_state_reuse(n):
69+
@pytest.mark.parametrize("barrier_type", BARRIER_TYPES)
70+
def test_barrier_state_reuse(barrier_type, n):
7171
"""Verify device barrier reuses the same flags tensor across calls."""
7272
shmem = iris.iris(1 << 20)
73-
shmem.barrier()
73+
_call_barrier(shmem, barrier_type)
7474

7575
try:
7676
shmem.device_barrier()
@@ -82,7 +82,7 @@ def test_barrier_state_reuse(n):
8282
shmem.device_barrier()
8383
assert shmem._device_barrier_state[None].data_ptr() == flags_ptr
8484
finally:
85-
shmem.barrier()
85+
_call_barrier(shmem, barrier_type)
8686
del shmem
8787
gc.collect()
8888

@@ -161,13 +161,13 @@ def _cross_rank_graph(
161161
buf,
162162
result,
163163
):
164-
stream = torch.cuda.Stream()
164+
capture_stream = torch.cuda.Stream()
165165

166166
if op == "load":
167167
buf.fill_(float(rank))
168168

169169
# Warmup on capture stream.
170-
with torch.cuda.stream(stream):
170+
with torch.cuda.stream(capture_stream):
171171
for _ in range(num_barriers):
172172
shmem.device_barrier()
173173
_read_remote_kernel[(1,)](
@@ -180,11 +180,11 @@ def _cross_rank_graph(
180180
)
181181
for _ in range(num_barriers):
182182
shmem.device_barrier()
183-
stream.synchronize()
183+
capture_stream.synchronize()
184184

185185
# Capture.
186186
graph = torch.cuda.CUDAGraph()
187-
with torch.cuda.graph(graph, stream=stream):
187+
with torch.cuda.graph(graph, stream=capture_stream):
188188
for _ in range(num_barriers):
189189
shmem.device_barrier()
190190
_read_remote_kernel[(1,)](
@@ -201,11 +201,11 @@ def _cross_rank_graph(
201201
# Replay with fresh data.
202202
for i in range(rounds):
203203
val = float(rank + (i + 1) * 10)
204-
buf.fill_(val)
205-
shmem.device_barrier()
206-
207-
graph.replay()
208-
stream.synchronize()
204+
with torch.cuda.stream(capture_stream):
205+
buf.fill_(val)
206+
shmem.device_barrier()
207+
graph.replay()
208+
capture_stream.synchronize()
209209

210210
expected = torch.full(
211211
(N,),
@@ -218,7 +218,7 @@ def _cross_rank_graph(
218218
buf.fill_(0.0)
219219

220220
# Warmup on capture stream.
221-
with torch.cuda.stream(stream):
221+
with torch.cuda.stream(capture_stream):
222222
for _ in range(num_barriers):
223223
shmem.device_barrier()
224224
_write_remote_kernel[(1,)](
@@ -231,11 +231,11 @@ def _cross_rank_graph(
231231
)
232232
for _ in range(num_barriers):
233233
shmem.device_barrier()
234-
stream.synchronize()
234+
capture_stream.synchronize()
235235

236236
# Capture.
237237
graph = torch.cuda.CUDAGraph()
238-
with torch.cuda.graph(graph, stream=stream):
238+
with torch.cuda.graph(graph, stream=capture_stream):
239239
for _ in range(num_barriers):
240240
shmem.device_barrier()
241241
_write_remote_kernel[(1,)](
@@ -251,13 +251,15 @@ def _cross_rank_graph(
251251

252252
# Replay and verify.
253253
for _ in range(rounds):
254-
buf.fill_(0.0)
255-
shmem.device_barrier()
256-
257-
graph.replay()
258-
stream.synchronize()
254+
with torch.cuda.stream(capture_stream):
255+
buf.fill_(0.0)
256+
shmem.device_barrier()
257+
graph.replay()
258+
capture_stream.synchronize()
259259

260-
shmem.device_barrier()
260+
with torch.cuda.stream(capture_stream):
261+
shmem.device_barrier()
262+
capture_stream.synchronize()
261263
expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda")
262264
torch.testing.assert_close(buf, expected, rtol=0, atol=0)
263265

@@ -288,7 +290,7 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
288290
)
289291

290292
shmem = iris.iris(1 << 20)
291-
shmem.barrier()
293+
_call_barrier(shmem, barrier_type)
292294
rank = shmem.get_rank()
293295
num_ranks = shmem.get_num_ranks()
294296
heap_bases = shmem.get_heap_bases()
@@ -332,44 +334,6 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
332334
result,
333335
)
334336
finally:
335-
shmem.barrier()
336-
del shmem
337-
gc.collect()
338-
339-
340-
def test_barrier_timeout_assert():
341-
"""Verify device_barrier asserts on timeout instead of hanging forever.
342-
343-
Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0
344-
spins waiting for them and hits the MAX_SPINS assert.
345-
"""
346-
shmem = iris.iris(1 << 20)
347-
rank = shmem.get_rank()
348-
num_ranks = shmem.get_num_ranks()
349-
heap_bases = shmem.get_heap_bases()
350-
351-
if num_ranks < 2:
352-
pytest.skip("Need at least 2 ranks")
353-
354-
shmem.barrier()
355-
356-
flags = shmem._device_barrier_state.setdefault(None, shmem.zeros((num_ranks,), dtype=torch.int32))
357-
358-
try:
359-
if rank == 0:
360-
_, rank_global, world_size, rank_start, rank_stride = extract_group_info(None, rank, num_ranks)
361-
_device_barrier_kernel[(1,)](
362-
flags,
363-
rank_global,
364-
world_size,
365-
rank_start,
366-
rank_stride,
367-
heap_bases,
368-
MAX_SPINS=1000,
369-
)
370-
with pytest.raises(RuntimeError, match="device-side assert"):
371-
torch.cuda.synchronize()
372-
finally:
373-
shmem.barrier()
337+
_call_barrier(shmem, barrier_type)
374338
del shmem
375339
gc.collect()

0 commit comments

Comments
 (0)