|
| 1 | +""" |
| 2 | +One-Shot All-Reduce Example |
| 3 | +======================================== |
| 4 | +This example demonstrates how to implement a one-shot pulling all-reduce operation |
| 5 | +using Helion and PyTorch's distributed capabilities. It includes a Helion kernel |
| 6 | +demonstrating how to do cross-device synchronization using symmetric memory signal pads |
| 7 | +and access symmetric memory tensor resident on peer devices. |
| 8 | +""" |
| 9 | + |
| 10 | +# %% |
| 11 | +# Imports |
| 12 | +# ------- |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +import os |
| 16 | + |
| 17 | +import torch |
| 18 | +import torch.distributed as dist |
| 19 | +import torch.distributed._symmetric_memory as symm_mem |
| 20 | +from torch.utils.cpp_extension import load_inline |
| 21 | + |
| 22 | +import helion |
| 23 | +import helion.language as hl |
| 24 | + |
| 25 | +# %% |
| 26 | +# TODO(joydddd): work around before symm mem natively supports extract dev_ptrs as tensors: from_blob |
| 27 | +from_blob_cpp = """ |
| 28 | +#include <cuda.h> |
| 29 | +#include <cuda_runtime.h> |
| 30 | +#include <iostream> |
| 31 | +
|
| 32 | +
|
| 33 | +at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) { |
| 34 | +
|
| 35 | + at::Tensor tensor = at::for_blob((void*)data_ptr, sizes) |
| 36 | + .deleter([](void *ptr) { |
| 37 | + ; |
| 38 | + }) |
| 39 | + .options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type)) |
| 40 | + .make_tensor(); |
| 41 | +
|
| 42 | + return tensor; |
| 43 | +} |
| 44 | +""" |
| 45 | + |
| 46 | +cpp_mod = load_inline( |
| 47 | + "cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"] |
| 48 | +) |
| 49 | + |
| 50 | + |
| 51 | +def dev_array_to_tensor_short( |
| 52 | + dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device |
| 53 | +) -> torch.Tensor: |
| 54 | + """ |
| 55 | + Convert a device array pointer to a PyTorch tensor. |
| 56 | +
|
| 57 | + This is a workaround function that creates a PyTorch tensor from a raw device pointer |
| 58 | + using the C++ extension. It's used to interface with symmetric memory device pointers |
| 59 | + before native support is available. |
| 60 | +
|
| 61 | + Args: |
| 62 | + dev_array_ptr: Raw device pointer as integer |
| 63 | + shape: Shape of the tensor to create |
| 64 | + dtype: PyTorch data type for the tensor |
| 65 | + device: Target device for the tensor |
| 66 | +
|
| 67 | + Returns: |
| 68 | + PyTorch tensor created from the device pointer |
| 69 | + """ |
| 70 | + return cpp_mod.from_blob(dev_array_ptr, shape, dtype) # pyright: ignore[reportAttributeAccessIssue] |
| 71 | + |
| 72 | + |
| 73 | +@helion.jit( |
| 74 | + config=helion.Config( |
| 75 | + block_sizes=[8192], |
| 76 | + num_warps=32, |
| 77 | + ), |
| 78 | + static_shapes=True, |
| 79 | +) |
| 80 | +def one_shot_all_reduce_kernel( |
| 81 | + signal_pad_addrs: torch.Tensor, |
| 82 | + local_signal_pad: torch.Tensor, |
| 83 | + a_shared_tuple: tuple[torch.Tensor, ...], |
| 84 | + my_rank: hl.constexpr, |
| 85 | +) -> torch.Tensor: |
| 86 | + """ |
| 87 | + Helion JIT-compiled kernel for one-shot all-reduce operation. |
| 88 | +
|
| 89 | + This kernel implements a distributed all-reduce using symmetric memory and signal pads |
| 90 | + for cross-device synchronization. It performs element-wise summation across all devices |
| 91 | + in the distributed group using tiled computation for memory efficiency. |
| 92 | +
|
| 93 | + Args: |
| 94 | + signal_pad_addrs: Tensor containing addresses of signal pads for all devices |
| 95 | + local_signal_pad: Local signal pad for synchronization |
| 96 | + a_shared_tuple: Tuple of shared tensors from all devices in the group |
| 97 | + my_rank: Current device's rank in the distributed group |
| 98 | +
|
| 99 | + Returns: |
| 100 | + Tensor containing the all-reduced result (sum across all devices) |
| 101 | + """ |
| 102 | + _, world_size = local_signal_pad.size() |
| 103 | + world_size = hl.specialize(world_size) |
| 104 | + out = torch.empty_like(a_shared_tuple[0]) |
| 105 | + N = out.size(0) |
| 106 | + |
| 107 | + for tile_n in hl.tile(N): |
| 108 | + ptr_tile = signal_pad_addrs[:] |
| 109 | + stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile) |
| 110 | + hl.signal( |
| 111 | + stack_signalpad, |
| 112 | + [tile_n.id, my_rank], |
| 113 | + signal=1, |
| 114 | + wait_for=0, |
| 115 | + scope="sys", |
| 116 | + hasPreviousMemAccess=False, |
| 117 | + ) |
| 118 | + |
| 119 | + for world in hl.tile(world_size, block_size=world_size): |
| 120 | + hl.wait( |
| 121 | + local_signal_pad, |
| 122 | + [tile_n.id, world], |
| 123 | + signal=1, |
| 124 | + update=0, |
| 125 | + scope="sys", |
| 126 | + ) |
| 127 | + |
| 128 | + acc = hl.zeros( |
| 129 | + [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device |
| 130 | + ) |
| 131 | + |
| 132 | + for a in a_shared_tuple: |
| 133 | + acc += a[tile_n] |
| 134 | + |
| 135 | + out[tile_n] = acc |
| 136 | + |
| 137 | + hl.signal( |
| 138 | + stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys" |
| 139 | + ) |
| 140 | + |
| 141 | + for world in hl.tile(world_size, block_size=world_size): |
| 142 | + hl.wait( |
| 143 | + local_signal_pad, |
| 144 | + [tile_n.id, world], |
| 145 | + signal=1, |
| 146 | + update=0, |
| 147 | + scope="sys", |
| 148 | + hasSubsequentMemAccess=False, |
| 149 | + ) |
| 150 | + return out |
| 151 | + |
| 152 | + |
| 153 | +def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: |
| 154 | + """ |
| 155 | + Prepares symmetric memory tensors for Helion one-shot all-reduce kernel. |
| 156 | + Tracks shared tensors as tuple of tensors, and/or dev_ptrs tensors. |
| 157 | +
|
| 158 | + Args: |
| 159 | + a_shared: Input tensor to be all-reduced across all devices |
| 160 | +
|
| 161 | + Returns: |
| 162 | + Tensor containing the all-reduced result (sum across all devices) |
| 163 | + """ |
| 164 | + assert dist.group.WORLD is not None |
| 165 | + |
| 166 | + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) |
| 167 | + |
| 168 | + a_shared_tuple = tuple( |
| 169 | + [ |
| 170 | + symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) |
| 171 | + for i in range(symm_mem_hdl.world_size) |
| 172 | + ] |
| 173 | + ) |
| 174 | + |
| 175 | + local_signal_pad = symm_mem_hdl.get_signal_pad( |
| 176 | + symm_mem_hdl.rank, dtype=torch.int32 |
| 177 | + ).view(-1, symm_mem_hdl.world_size) |
| 178 | + |
| 179 | + signal_pad_addrs = dev_array_to_tensor_short( |
| 180 | + symm_mem_hdl.signal_pad_ptrs_dev, |
| 181 | + (symm_mem_hdl.world_size,), |
| 182 | + dtype=torch.uint64, |
| 183 | + device=a_shared.device, |
| 184 | + ) |
| 185 | + |
| 186 | + return one_shot_all_reduce_kernel( |
| 187 | + signal_pad_addrs, |
| 188 | + local_signal_pad, |
| 189 | + a_shared_tuple, |
| 190 | + my_rank=symm_mem_hdl.rank, |
| 191 | + ) |
| 192 | + |
| 193 | + |
| 194 | +def test(N: int, device: torch.device, dtype: torch.dtype) -> None: |
| 195 | + """ |
| 196 | + Test the Helion all-reduce implementation against PyTorch's reference implementation. |
| 197 | + Args: |
| 198 | + N: Total number of elements to test (will be divided by world_size per device) |
| 199 | + device: CUDA device to run the test on |
| 200 | + dtype: Data type for the test tensors |
| 201 | + """ |
| 202 | + dist_group = dist.group.WORLD |
| 203 | + assert dist_group is not None |
| 204 | + |
| 205 | + world_size = dist.get_world_size() |
| 206 | + a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_() |
| 207 | + |
| 208 | + a_shared_clone = symm_mem.empty( |
| 209 | + a_shared.shape, |
| 210 | + dtype=a_shared.dtype, |
| 211 | + device=a_shared.device, |
| 212 | + ) |
| 213 | + symm_mem.rendezvous(a_shared_clone, dist_group.group_name) |
| 214 | + a_shared_clone.copy_(a_shared) |
| 215 | + |
| 216 | + a_out = helion_one_shot_all_reduce(a_shared) |
| 217 | + |
| 218 | + gloden_o = torch.ops.symm_mem.one_shot_all_reduce( |
| 219 | + a_shared_clone, "sum", dist_group.group_name |
| 220 | + ) |
| 221 | + |
| 222 | + torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1) |
| 223 | + |
| 224 | + |
| 225 | +def main() -> None: |
| 226 | + """ |
| 227 | + Main entry point for the all-reduce example. |
| 228 | +
|
| 229 | + Sets up the distributed environment, initializes CUDA devices, and runs the |
| 230 | + all-reduce test, and then clean up. |
| 231 | + """ |
| 232 | + rank = int(os.environ["LOCAL_RANK"]) |
| 233 | + torch.manual_seed(42 + rank) |
| 234 | + device = torch.device(f"cuda:{rank}") |
| 235 | + torch.cuda.set_device(device) |
| 236 | + dist.init_process_group("nccl") |
| 237 | + test(16384, device, torch.bfloat16) |
| 238 | + |
| 239 | + dist.destroy_process_group() |
| 240 | + |
| 241 | + |
| 242 | +if __name__ == "__main__": |
| 243 | + """ |
| 244 | + Run with: |
| 245 | + torchrun \ |
| 246 | + --nnodes 1 --nproc-per-node 8 \ |
| 247 | + --rdzv-backend c10d --rdzv-endpoint localhost:0 \ |
| 248 | + --no_python python3 examples/all_reduce.py |
| 249 | + """ |
| 250 | + main() |
0 commit comments