|
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

"""
All-gather communication primitive using Iris.

This module provides an all-gather operation along the M dimension using
GPU-initiated communication via the Iris library.
"""

import logging

import torch
from torch import Tensor
import triton
import triton.language as tl

try:
    import iris

    IRIS_AVAILABLE = True
except ImportError:
    iris = None
    IRIS_AVAILABLE = False

logger = logging.getLogger("aiter")


@triton.jit
def _all_gather_impl(
    pid,
    shard_ptr,
    out_ptr,
    M,
    M_shard,
    N,
    stride_sm,
    stride_sn,
    stride_om,
    stride_on,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    heap_bases: tl.tensor,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    """
    Shared all-gather implementation using a push-based approach with iris.put
    and a 1D persistent-style PID mapping.

    Each rank sends its [M_shard, N] shard to every other rank at the
    appropriate row offset in the output buffer.

    Args:
        pid: Program ID for the 1D persistent-style mapping, either from
            tl.program_id(0) or passed in from a parent kernel.
    """
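    # Illustration (hypothetical sizes): with world_size = 4 and M_shard = 256,
    # rank 2 loads rows 0..255 of its local shard and writes them to rows
    # 512..767 (i.e. cur_rank * M_shard + rm_local) of the output buffer on
    # every rank, including itself.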
    num_pid_m = tl.cdiv(M_shard, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_pid_m * num_pid_n

    # Persistent loop over tiles
    for tile_id in range(pid, total_tiles, NUM_SMS):
        # Swizzle pattern
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
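        # Example (hypothetical tile counts): with num_pid_m = 16, num_pid_n = 4
        # and GROUP_SIZE_M = 8, num_pid_in_group = 32, so tile_id 0..7 map to
        # (pid_m, pid_n) = (0, 0)..(7, 0) and tile_id 8..15 map to
        # (0, 1)..(7, 1); the next group of rows starts at tile_id 32.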

        # Local indices
        rm_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rm_local = tl.max_contiguous(tl.multiple_of(rm_local, BLOCK_M), BLOCK_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
        mask_m_local = rm_local < M_shard
        mask_n = rn < N

        # Load local shard
        shard_ptrs = shard_ptr + rm_local[:, None] * stride_sm + rn[None, :] * stride_sn
        shard_data = tl.load(
            shard_ptrs, mask=mask_m_local[:, None] & mask_n[None, :], other=0.0
        )

        # Global M indices: this rank's shard always lands at row offset
        # cur_rank * M_shard, so the destination addresses are the same for
        # every dst rank and can be computed once outside the loop.
        rm_global = cur_rank * M_shard + rm_local
        mask_m_global = rm_global < M
        final_mask = mask_m_global[:, None] & mask_n[None, :]
        out_ptrs = out_ptr + rm_global[:, None] * stride_om + rn[None, :] * stride_on

        # Send the tile to all ranks at the appropriate M offset
        for dst in range(world_size):
            if dst == cur_rank:
                # Local store
                tl.store(out_ptrs, shard_data, mask=final_mask)
            else:
                # Remote store using iris.put
                # from_ptr: local source, to_ptr: remote destination
                iris.put(
                    shard_ptrs,
                    out_ptrs,
                    cur_rank,
                    dst,
                    heap_bases,
                    mask=final_mask,
                )


@triton.jit
def _all_gather_kernel(
    shard_ptr,  # *[M_shard, N]
    out_ptr,  # *[M, N]
    M,
    M_shard,
    N,
    stride_sm,
    stride_sn,
    stride_om,
    stride_on,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    heap_bases: tl.tensor,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    """
    All-gather kernel entry point.

    A thin wrapper around _all_gather_impl that supplies the program ID from
    tl.program_id(0).
    """
    pid = tl.program_id(0)
    _all_gather_impl(
        pid,
        shard_ptr,
        out_ptr,
        M,
        M_shard,
        N,
        stride_sm,
        stride_sn,
        stride_om,
        stride_on,
        cur_rank,
        world_size,
        heap_bases,
        BLOCK_M,
        BLOCK_N,
        GROUP_SIZE_M,
        NUM_SMS,
    )


def all_gather(
    input_shard: Tensor,
    ctx: "IrisCommContext" = None,
    block_m: int = 64,
    block_n: int = 64,
    group_size_m: int = 8,
    num_sms: int = 256,
) -> Tensor:
    """
    Perform an all-gather along the M (row) dimension.

    This operation:
    1. Each rank holds a shard of shape [M_shard, N]
    2. All ranks send their shards to all other ranks
    3. Each rank ends up with the full tensor of shape [M, N], where M = M_shard * world_size

    Args:
        input_shard (Tensor): Input shard of shape [M_shard, N] in Iris shared memory
        ctx (IrisCommContext): Iris communication context; must be initialized
            (use IrisCommContext as a context manager).
        block_m (int): Block size for M dimension. Default: 64
        block_n (int): Block size for N dimension. Default: 64
        group_size_m (int): Group size for swizzling. Default: 8
        num_sms (int): Number of SMs to use (persistent kernel). Default: 256

    Returns:
        Tensor: Full tensor of shape [M, N] where M = M_shard * world_size

    Example:
        >>> with IrisCommContext() as ctx:
        ...     input_shard = ctx.iris_ctx.zeros((1024, 7168), dtype=torch.float32)
        ...     # ... initialize input_shard ...
        ...     full_tensor = all_gather(input_shard, ctx)
        ...     print(full_tensor.shape)  # [8192, 7168] for world_size=8
    """
    if not IRIS_AVAILABLE:
        raise RuntimeError("Iris library is not available. Cannot perform all-gather.")

    if ctx is None or not ctx.is_initialized:
        raise RuntimeError(
            "Iris context not initialized. Use IrisCommContext as a context manager."
        )

    # Get distributed parameters from the context
    cur_rank = ctx.cur_rank
    world_size = ctx.num_ranks
    heap_bases = ctx.get_heap_bases()
    iris_ctx = ctx.iris_ctx

    # Input shape
    M_shard, N = input_shard.shape
    M = M_shard * world_size
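    # All ranks are assumed to hold equally sized shards of M_shard rows;
    # rank r's shard occupies rows [r * M_shard, (r + 1) * M_shard) of the output.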

    logger.info(
        f"Rank {cur_rank}/{world_size}: All-gather M_shard={M_shard}, N={N} -> M={M}"
    )

    # Allocate the output buffer in Iris shared memory
    full_output = iris_ctx.zeros((M, N), dtype=input_shard.dtype)

    # Launch the persistent kernel
    grid = (num_sms,)
    _all_gather_kernel[grid](
        input_shard,
        full_output,
        M,
        M_shard,
        N,
        input_shard.stride(0),
        input_shard.stride(1),
        full_output.stride(0),
        full_output.stride(1),
        cur_rank,
        world_size,
        heap_bases,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=num_sms,
        num_warps=16,
        num_stages=4,
        waves_per_eu=4,
    )

    # Synchronize: wait for the local kernel to finish, then barrier across
    # ranks so every rank has finished pushing its shard before the gathered
    # tensor is used.
    torch.cuda.synchronize()
    iris_ctx.barrier()

    logger.info(
        f"Rank {cur_rank}: All-gather complete, output shape: {full_output.shape}"
    )

    return full_output
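

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. It assumes one process
    # per GPU (e.g. launched via torchrun) and that IrisCommContext is
    # importable; the import path below is hypothetical and should be adjusted
    # to wherever the context class lives in your tree.
    from aiter.ops.triton.comm import IrisCommContext  # hypothetical import path

    with IrisCommContext() as ctx:
        # Each rank fills its shard with its own rank id so the gathered rows
        # are easy to attribute.
        shard = ctx.iris_ctx.zeros((1024, 7168), dtype=torch.float32)
        shard.fill_(ctx.cur_rank)
        full = all_gather(shard, ctx)
        logger.info(
            f"Rank {ctx.cur_rank}: gathered tensor shape {tuple(full.shape)}"
        )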