Skip to content

Commit b9f8264

Browse files
committed
add device barrier
1 parent 2afde85 commit b9f8264

File tree

4 files changed

+531
-76
lines changed

4 files changed

+531
-76
lines changed

iris/_distributed_helpers.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
33

4+
45
import torch
56
import torch.distributed as dist
67
import numpy as np
8+
import triton
9+
import triton.language as tl
710

811

912
def _infer_device():
@@ -207,6 +210,77 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0):
207210
return obj[0]
208211

209212

213+
def extract_group_info(group, rank, num_ranks):
    """
    Extract rank and stride information for a process group.

    Args:
        group: ProcessGroup or None. If None, uses the provided rank/num_ranks
            as the default (all-ranks) group.
        rank: Global rank of the current process.
        num_ranks: Total number of ranks in the default group.

    Returns:
        Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride):
            - rank_in_group: Rank within the group (0-indexed)
            - rank_global: Global rank of this process
            - world_size: Number of ranks in the group
            - rank_start: Starting global rank of the group
            - rank_stride: Stride between consecutive ranks in the group

    Raises:
        RuntimeError: If torch.distributed is not initialized, or the current
            rank is not a member of ``group``.
        NotImplementedError: If the group's ranks are not uniformly strided.
        ValueError: If the group contains duplicate ranks (stride of 0).

    Examples:
        >>> # group=None: all ranks [0,1,2,3], current global rank is 2
        >>> extract_group_info(None, 2, 4)
        (2, 2, 4, 0, 1)

        >>> # DP group: strided ranks [0,4,8,12], current global rank is 8
        >>> extract_group_info(dp_group, 8, 16)
        (2, 8, 4, 0, 4)
    """
    # Default (all-ranks) group: identity mapping with unit stride.
    if group is None:
        return rank, rank, num_ranks, 0, 1

    if not dist.is_initialized():
        raise RuntimeError(
            "torch.distributed must be initialized to use ProcessGroup. "
            "Call torch.distributed.init_process_group() first."
        )

    group_ranks = dist.get_process_group_ranks(group)
    world_size = len(group_ranks)
    rank_global = dist.get_rank()

    if rank_global not in group_ranks:
        raise RuntimeError(
            f"Current rank {rank_global} is not part of the specified process group. "
            f"Group contains ranks: {group_ranks}"
        )

    rank_in_group = group_ranks.index(rank_global)
    rank_start = group_ranks[0]

    if world_size == 1:
        # A singleton group trivially has unit stride.
        return rank_in_group, rank_global, world_size, rank_start, 1

    # Gather the distinct gaps between consecutive group ranks; a valid
    # strided group collapses to exactly one gap value.
    deltas = {b - a for a, b in zip(group_ranks, group_ranks[1:])}
    if len(deltas) != 1:
        raise NotImplementedError(
            f"Non-strided process groups are not yet supported. "
            f"Group ranks: {group_ranks}. "
            f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])."
        )

    rank_stride = deltas.pop()
    if rank_stride == 0:
        # All gaps were zero, i.e. the same rank appears more than once.
        raise ValueError(
            f"Invalid process group: rank_stride is 0, indicating duplicate ranks. "
            f"Group ranks: {group_ranks}. "
            f"Each rank must appear exactly once in a process group."
        )

    return rank_in_group, rank_global, world_size, rank_start, rank_stride
282+
283+
210284
def distributed_barrier(group=None):
211285
"""
212286
Synchronization barrier using PyTorch distributed.
@@ -220,6 +294,92 @@ def distributed_barrier(group=None):
220294
dist.barrier(group=group)
221295

222296

297+
@triton.jit
def _translate_ptr(ptr, from_rank, to_rank, heap_bases):
    """
    Translate a pointer from one rank's address space to another's.

    Each rank's heap base address is stored in ``heap_bases`` (indexed by
    rank). A pointer valid in ``from_rank``'s view is rebased by taking its
    byte offset from ``from_rank``'s heap base and adding it to ``to_rank``'s
    heap base. When from_rank == to_rank this is an identity translation.

    Args:
        ptr: Pointer into from_rank's view of the heap.
        from_rank: Rank whose address space ``ptr`` belongs to.
        to_rank: Rank whose address space to translate into.
        heap_bases: Per-rank array of heap base addresses.

    Returns:
        The translated pointer, with the same pointee dtype as ``ptr``.
    """
    from_base = tl.load(heap_bases + from_rank)
    to_base = tl.load(heap_bases + to_rank)
    # Byte offset of ptr within from_rank's heap.
    offset = tl.cast(ptr, tl.uint64) - from_base
    # Rebase through an int8 pointer so the addition is byte-granular,
    # then cast back to the original pointer type.
    translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype)
    return translated_ptr
305+
306+
307+
@triton.jit
def _device_barrier_kernel(
    flags_ptr,  # base of the per-rank int32 flag array on the symmetric heap
    iris_rank,  # global rank of the calling process
    world_size: tl.constexpr,  # number of ranks in the group
    rank_start,  # first global rank of the group
    rank_stride,  # stride between consecutive global ranks in the group
    heap_bases,  # per-rank heap base addresses for pointer translation
):
    """
    Stateless device-side barrier using atomic operations on the symmetric heap.

    Launched with grid=(1,). A single CTA:
    1. Atomically increments its own flag (atomic_add, release)
    2. Serially polls each remote rank's flag for the same value (acquire)

    No CPU-side epoch tracking. Each rank's flag IS the epoch, managed
    entirely on the GPU via atomic_add. This makes the barrier safe for
    CUDA graph capture: during recording the kernel is just recorded,
    during replay all ranks increment together.
    """
    # Increment own flag and determine target. The release semantics publish
    # all of this rank's prior writes before peers can observe the new epoch.
    # Translating with from_rank == to_rank is an identity mapping.
    own_flag_ptr = flags_ptr + iris_rank
    own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases)
    old = tl.atomic_add(own_translated, 1, sem="release", scope="sys")
    # Every rank must reach at least this epoch value before we may proceed.
    target = old + 1

    # Poll each remote rank serially. The flag index is the remote rank's
    # GLOBAL rank, so strided subgroups poll only their own members.
    for i in range(world_size):
        remote_rank = rank_start + i * rank_stride
        if remote_rank != iris_rank:
            remote_flag_ptr = flags_ptr + remote_rank
            remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases)
            # atomic_cas(ptr, target, target) is a no-op swap used as an
            # atomic read with acquire ordering; it returns the current flag
            # value. Spin while the remote epoch is BELOW ours -- using
            # "< target" rather than "!= target" so a peer that has already
            # advanced to a later epoch does not deadlock us.
            while (
                tl.atomic_cas(
                    remote_translated,
                    target,
                    target,
                    sem="acquire",
                    scope="sys",
                )
                < target
            ):
                pass
351+
352+
353+
def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases):
    """
    Stateless device-side barrier using atomic operations on the symmetric heap.

    Unlike ``distributed_barrier`` which uses host-side
    ``torch.distributed.barrier()``, this launches a single-CTA Triton kernel
    that synchronizes via device-side atomics, making it safe to use during
    CUDA graph capture.

    No CPU-side epoch tracking is needed: each rank's flag on the symmetric
    heap serves as its own epoch counter, advanced entirely by the GPU via
    atomic_add.

    Args:
        flags: int32 tensor on symmetric heap, one element per rank.
        group: ProcessGroup or None. If None, uses all ranks.
        rank: Global rank of this process.
        num_ranks: Total number of ranks in the default group.
        heap_bases: Tensor of heap base addresses for all ranks.
    """
    group_info = extract_group_info(group, rank, num_ranks)
    _, rank_global, world_size, rank_start, rank_stride = group_info
    # A single CTA suffices: the kernel itself loops over all peer ranks.
    grid = (1,)
    _device_barrier_kernel[grid](flags, rank_global, world_size, rank_start, rank_stride, heap_bases)
381+
382+
223383
def init_distributed():
224384
"""
225385
Initialize PyTorch distributed and return communicator info.

iris/ccl/utils.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Tuple
1010
import triton
1111
import triton.language as tl
12+
from iris._distributed_helpers import extract_group_info as _extract_group_info
1213

1314

1415
@triton.jit()
@@ -67,83 +68,11 @@ def extract_group_info(group, shmem) -> Tuple[int, int, int, int, int]:
6768
6869
Returns:
6970
Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride)
70-
- rank_in_group: Rank within the group (0-indexed), used for tile assignment and comparisons
71-
- rank_global: Global rank of this process, used for iris RMA operations (heap_bases indexing)
71+
- rank_in_group: Rank within the group (0-indexed)
72+
- rank_global: Global rank of this process
7273
- world_size: Number of ranks in the group
7374
- rank_start: Starting global rank of the group
7475
- rank_stride: Stride between consecutive ranks in the group
75-
76-
Examples:
77-
>>> # group=None: all ranks [0,1,2,3], current global rank is 2
78-
>>> extract_group_info(None, shmem)
79-
(2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1
80-
81-
>>> # TP group: consecutive ranks [0,1,2,3], current global rank is 2
82-
>>> extract_group_info(tp_group, shmem)
83-
(2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1
84-
85-
>>> # DP group: strided ranks [0,4,8,12], current global rank is 8
86-
>>> extract_group_info(dp_group, shmem)
87-
(2, 8, 4, 0, 4) # rank_in_group=2, rank_global=8, world_size=4, start=0, stride=4
8876
"""
89-
if group is None:
90-
# Use all ranks in shmem context
91-
# When group is None, rank_in_group equals rank_global
92-
rank_global = shmem.get_rank()
93-
rank_in_group = rank_global
94-
world_size = shmem.get_num_ranks()
95-
rank_start = 0
96-
rank_stride = 1
97-
return rank_in_group, rank_global, world_size, rank_start, rank_stride
98-
99-
# Extract from ProcessGroup
100-
import torch.distributed as dist
101-
102-
if not dist.is_initialized():
103-
raise RuntimeError(
104-
"torch.distributed must be initialized to use ProcessGroup. "
105-
"Call torch.distributed.init_process_group() first."
106-
)
107-
108-
group_ranks = dist.get_process_group_ranks(group)
109-
world_size = len(group_ranks)
110-
rank_global = dist.get_rank()
111-
112-
if rank_global not in group_ranks:
113-
raise RuntimeError(
114-
f"Current rank {rank_global} is not part of the specified process group. "
115-
f"Group contains ranks: {group_ranks}"
116-
)
117-
118-
rank_in_group = group_ranks.index(rank_global)
119-
120-
# Detect stride pattern
121-
if len(group_ranks) > 1:
122-
# Check if all consecutive pairs have the same stride
123-
strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))]
124-
is_strided = all(s == strides[0] for s in strides)
125-
126-
if is_strided:
127-
rank_start = group_ranks[0]
128-
rank_stride = strides[0]
129-
130-
# Validate rank_stride is not zero (would indicate duplicate ranks)
131-
if rank_stride == 0:
132-
raise ValueError(
133-
f"Invalid process group: rank_stride is 0, indicating duplicate ranks. "
134-
f"Group ranks: {group_ranks}. "
135-
f"Each rank must appear exactly once in a process group."
136-
)
137-
else:
138-
# Non-strided group - not supported yet
139-
raise NotImplementedError(
140-
f"Non-strided process groups are not yet supported. "
141-
f"Group ranks: {group_ranks}. "
142-
f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])."
143-
)
144-
else:
145-
# Single rank group
146-
rank_start = group_ranks[0]
147-
rank_stride = 1
148-
149-
return rank_in_group, rank_global, world_size, rank_start, rank_stride
77+
78+
return _extract_group_info(group, shmem.get_rank(), shmem.get_num_ranks())

iris/iris.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from iris._distributed_helpers import (
4646
init_distributed,
4747
distributed_barrier,
48+
distributed_device_barrier,
4849
distributed_broadcast_scalar,
4950
distributed_broadcast_tensor,
5051
)
@@ -55,6 +56,7 @@
5556
)
5657
from iris.symmetric_heap import SymmetricHeap
5758
import numpy as np
59+
from typing import Any
5860
import torch
5961
import logging
6062

@@ -135,6 +137,9 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"):
135137
# Lazy initialization for ops interface
136138
self._ops = None
137139

140+
# Device-side barrier state, keyed by process group (None = all ranks).
141+
self._device_barrier_state: dict[Any, torch.Tensor] = {}
142+
138143
# Initialize tracing
139144
self.tracing = Tracing(self)
140145

@@ -989,6 +994,34 @@ def barrier(self, stream=None, group=None):
989994
# Distributed barrier
990995
distributed_barrier(group=group)
991996

997+
def device_barrier(self, group=None):
    """
    Stateless device-side barrier that is CUDA graph capturable.

    Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``,
    this synchronizes ranks with device-side atomic operations on the
    symmetric heap. There is no CPU-side epoch tracking: each rank's flag on
    the heap acts as its own epoch counter, advanced entirely on the GPU via
    atomic_add.

    Args:
        group (ProcessGroup, optional): The process group to synchronize.
            If None, uses all ranks in the shmem context.

    Example:
        >>> ctx = iris.iris(1 << 20)
        >>> ctx.device_barrier()  # Synchronize all ranks on device
    """
    flags = self._device_barrier_state.get(group)
    if flags is None:
        # Lazily allocate one int32 flag per global rank on the symmetric
        # heap, cached per group so repeated barriers reuse the same tensor.
        flags = self.zeros((self.num_ranks,), dtype=torch.int32)
        self._device_barrier_state[group] = flags

    distributed_device_barrier(
        flags,
        group,
        self.cur_rank,
        self.num_ranks,
        self.get_heap_bases(),
    )
1024+
9921025
def get_device(self):
9931026
"""
9941027
Get the underlying device where the Iris symmetric heap resides.

0 commit comments

Comments
 (0)