11# SPDX-License-Identifier: MIT
22# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
33
4+
45import torch
56import torch .distributed as dist
67import numpy as np
8+ import triton
9+ import triton .language as tl
710
811
912def _infer_device ():
@@ -207,6 +210,77 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0):
207210 return obj [0 ]
208211
209212
def extract_group_info(group, rank, num_ranks):
    """
    Resolve a process group into (rank, stride) addressing information.

    Args:
        group: ProcessGroup or None. If None, uses the provided rank/num_ranks
            as the default (all-ranks) group.
        rank: Global rank of the current process.
        num_ranks: Total number of ranks in the default group.

    Returns:
        Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride):
            - rank_in_group: Rank within the group (0-indexed)
            - rank_global: Global rank of this process
            - world_size: Number of ranks in the group
            - rank_start: Starting global rank of the group
            - rank_stride: Stride between consecutive ranks in the group

    Raises:
        RuntimeError: If torch.distributed is not initialized, or the current
            rank is not a member of ``group``.
        NotImplementedError: If the group's ranks are not uniformly strided.
        ValueError: If the group contains duplicate ranks (stride of zero).

    Examples:
        >>> # group=None: all ranks [0,1,2,3], current global rank is 2
        >>> extract_group_info(None, 2, 4)
        (2, 2, 4, 0, 1)

        >>> # DP group: strided ranks [0,4,8,12], current global rank is 8
        >>> extract_group_info(dp_group, 8, 16)
        (2, 8, 4, 0, 4)
    """
    # Fast path: the default group is all ranks, contiguous from 0.
    if group is None:
        return rank, rank, num_ranks, 0, 1

    if not dist.is_initialized():
        raise RuntimeError(
            "torch.distributed must be initialized to use ProcessGroup. "
            "Call torch.distributed.init_process_group() first."
        )

    group_ranks = dist.get_process_group_ranks(group)
    world_size = len(group_ranks)
    rank_global = dist.get_rank()

    if rank_global not in group_ranks:
        raise RuntimeError(
            f"Current rank {rank_global} is not part of the specified process group. "
            f"Group contains ranks: {group_ranks}"
        )

    rank_in_group = group_ranks.index(rank_global)
    rank_start = group_ranks[0]

    # A single-member group trivially has stride 1.
    if world_size == 1:
        return rank_in_group, rank_global, world_size, rank_start, 1

    # All consecutive gaps must be identical for a strided group.
    gaps = {b - a for a, b in zip(group_ranks, group_ranks[1:])}
    if len(gaps) != 1:
        raise NotImplementedError(
            f"Non-strided process groups are not yet supported. "
            f"Group ranks: {group_ranks}. "
            f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])."
        )

    rank_stride = gaps.pop()
    if rank_stride == 0:
        # Uniform gap of zero means the same rank appears more than once.
        raise ValueError(
            f"Invalid process group: rank_stride is 0, indicating duplicate ranks. "
            f"Group ranks: {group_ranks}. "
            f"Each rank must appear exactly once in a process group."
        )

    return rank_in_group, rank_global, world_size, rank_start, rank_stride
282+
283+
210284def distributed_barrier (group = None ):
211285 """
212286 Synchronization barrier using PyTorch distributed.
@@ -220,6 +294,92 @@ def distributed_barrier(group=None):
220294 dist .barrier (group = group )
221295
222296
@triton.jit
def _translate_ptr(ptr, from_rank, to_rank, heap_bases):
    """Translate a pointer from one rank's address space to another's.

    Both ranks are assumed to hold a symmetric heap, so a pointer into
    ``from_rank``'s heap maps to the same byte offset in ``to_rank``'s heap.

    Args:
        ptr: Device pointer into ``from_rank``'s heap.
        from_rank: Rank whose address space ``ptr`` belongs to.
        to_rank: Rank whose address space to translate into.
        heap_bases: Per-rank heap base addresses, indexed by rank
            (presumably stored as uint64 — confirm against allocator).

    Returns:
        A pointer of the same dtype as ``ptr`` at the corresponding offset
        in ``to_rank``'s heap.
    """
    from_base = tl.load(heap_bases + from_rank)
    to_base = tl.load(heap_bases + to_rank)
    # Byte offset of `ptr` within the source rank's heap.
    offset = tl.cast(ptr, tl.uint64) - from_base
    # Rebase via an int8 pointer so the addition is in bytes, then restore
    # the original pointer type.
    translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype)
    return translated_ptr
305+
306+
@triton.jit
def _device_barrier_kernel(
    flags_ptr,
    iris_rank,
    world_size: tl.constexpr,
    rank_start,
    rank_stride,
    heap_bases,
):
    """
    Stateless device-side barrier using atomic operations on the symmetric heap.

    Launched with grid=(1,). A single CTA:
    1. Atomically increments its own flag (atomic_add, release)
    2. Serially polls each remote rank's flag for the same value (acquire)

    No CPU-side epoch tracking. Each rank's flag IS the epoch, managed
    entirely on the GPU via atomic_add. This makes the barrier safe for
    CUDA graph capture: during recording the kernel is just recorded,
    during replay all ranks increment together.

    Args:
        flags_ptr: Pointer to the per-rank flag array on the symmetric heap,
            indexed by global rank.
        iris_rank: This process's global rank.
        world_size: Number of ranks participating in the barrier (constexpr).
        rank_start: First global rank of the group.
        rank_stride: Stride between consecutive global ranks of the group.
        heap_bases: Per-rank heap base addresses for pointer translation.
    """
    # Increment own flag and determine target epoch. Translating
    # iris_rank -> iris_rank is an identity mapping; it keeps the
    # addressing path uniform with the remote case below.
    own_flag_ptr = flags_ptr + iris_rank
    own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases)
    # release: makes this rank's prior writes visible before peers observe
    # the flag; sys scope so other devices/agents see it.
    old = tl.atomic_add(own_translated, 1, sem="release", scope="sys")
    target = old + 1

    # Poll each remote rank serially until its flag reaches `target`.
    for i in range(world_size):
        remote_rank = rank_start + i * rank_stride
        if remote_rank != iris_rank:
            remote_flag_ptr = flags_ptr + remote_rank
            remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases)
            # CAS(target, target) is a read that never changes the value:
            # it returns the current flag with acquire ordering. Spin while
            # the remote flag is still behind; a value > target (peer already
            # in a later epoch) also exits the loop, which is correct since
            # epochs only move forward.
            while (
                tl.atomic_cas(
                    remote_translated,
                    target,
                    target,
                    sem="acquire",
                    scope="sys",
                )
                < target
            ):
                pass
351+
352+
def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases):
    """
    Device-side barrier across ranks, driven entirely by GPU atomics.

    Unlike ``distributed_barrier``, which relies on the host-side
    ``torch.distributed.barrier()``, this launches a one-CTA Triton kernel
    that synchronizes through atomics on the symmetric heap, so it remains
    usable while a CUDA graph is being captured.

    There is no host-side epoch state: each rank's flag on the symmetric
    heap doubles as its epoch counter and is advanced on-device via
    atomic_add.

    Args:
        flags: int32 tensor on symmetric heap, one element per rank.
        group: ProcessGroup or None. If None, uses all ranks.
        rank: Global rank of this process.
        num_ranks: Total number of ranks in the default group.
        heap_bases: Tensor of heap base addresses for all ranks.
    """
    group_info = extract_group_info(group, rank, num_ranks)
    _, global_rank, group_size, start_rank, stride = group_info

    # One program instance is enough; the kernel itself spins on peers.
    grid = (1,)
    _device_barrier_kernel[grid](
        flags,
        global_rank,
        group_size,
        start_rank,
        stride,
        heap_bases,
    )
381+
382+
223383def init_distributed ():
224384 """
225385 Initialize PyTorch distributed and return communicator info.
0 commit comments