|
| 1 | +```{eval-rst} |
| 2 | +.. role:: hidden |
| 3 | + :class: hidden-section |
| 4 | +``` |
| 5 | + |
| 6 | +# PyTorch Symmetric Memory |
| 7 | + |
| 8 | +:::{note} |
| 9 | +`torch.distributed._symmetric_memory` is currently in alpha state and under |
| 10 | +development. API changes may be possible. |
| 11 | +::: |
| 12 | + |
| 13 | +## Why Symmetric Memory? |
| 14 | + |
| 15 | +With rapidly evolving parallelization techniques, existing frameworks and |
| 16 | +libraries often struggle to keep up, and developers increasingly rely on custom |
| 17 | +implementations directly scheduling communications and computations. In recent |
| 18 | +years we’ve witnessed a shift from primarily relying on one-dimensional |
| 19 | +data-parallelism techniques to multi-dimensional parallelism ones. The latter |
| 20 | +have different latency requirements for different types of communications and |
| 21 | +thus require fine-grained overlapping of compute and communications. |
| 22 | + |
| 23 | +To minimize compute interference, they also require the use of copy engines and |
| 24 | +network interface cards (NICs) to drive communication. Network transport |
| 25 | +protocols such as remote direct memory access (RDMA) enhance the performance by |
| 26 | +enabling direct, high-speed, and low-latency communication between processors |
| 27 | +and memory. This increase in variety indicates the need for finer-grained |
| 28 | +communication primitives than are offered today by high-level collective APIs, |
| 29 | +ones that would enable developers to implement specific algorithms tailored for |
| 30 | +their use cases, such as low-latency collectives, fine-grained |
| 31 | +compute-communications overlap, or custom fusions. |
| 32 | + |
| 33 | +Furthermore, today’s advanced AI systems connect GPUs with high-bandwidth links |
| 34 | +(such as NVLinks, InfiniBand or RoCE), making GPU global memory directly |
| 35 | +accessible to peers. Such connections present a great opportunity for |
| 36 | +programmers to program the system as a single, gigantic GPU with vast accessible |
| 37 | +memory, instead of programming singular “GPU islands.” |
| 38 | + |
| 39 | +In this document, we will show how you can use PyTorch Symmetric Memory to |
| 40 | +program modern GPU systems as a “single GPU” and achieve fine-grained remote |
| 41 | +access. |
| 42 | + |
| 43 | +## What PyTorch Symmetric Memory unlocks? |
| 44 | + |
| 45 | +PyTorch Symmetric Memory unlocks three new capabilities: |
| 46 | + |
| 47 | +- **Customized communication patterns**: Increased flexibility in kernel writing |
| 48 | +allows developers to write custom kernels that implement their custom |
| 49 | +computations and communications, directly tailored to the need of the |
| 50 | +application. It will also be straightforward to add support for new data types |
| 51 | +along with the special compute that those data types might require, even if it’s |
| 52 | +not present yet in the standard libraries. |
| 53 | + |
| 54 | +- **In-kernel compute-comm fusion**: Device-initiated communication capability |
| 55 | +allows developers to write kernels with both computation and communication |
| 56 | +instructions, allowing for the fusion of computation and data movement in the |
| 57 | +smallest possible granularity. |
| 58 | + |
| 59 | +- **Low-latency remote access**: Network transport protocols like RDMA enhance the |
| 60 | +performance of symmetric memory in networked environments by enabling direct, |
| 61 | +high-speed, and low-latency communication between processors and memory. RDMA |
| 62 | +eliminates the overhead associated with the traditional network stack and CPU |
| 63 | +involvement. It also offloads data transfer from the compute to the NICs, |
| 64 | +freeing up compute resources for computational tasks. |
| 65 | + |
| 66 | +Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new |
| 67 | +applications with the above capabilities. |
| 68 | + |
| 69 | +## A “Hello World” example |
| 70 | + |
| 71 | +The PyTorch SymmMem programming model involves two key elements: |
| 72 | + |
| 73 | +- creating symmetric tensors |
| 74 | +- creating SymmMem kernels |
| 75 | + |
| 76 | +To create symmetric tensors, one can use the |
| 77 | +`torch.distributed._symmetric_memory` package: |
| 78 | + |
| 79 | +```python |
| 80 | +import torch.distributed._symmetric_memory as symm_mem |
| 81 | + |
| 82 | +t = symm_mem.empty(128, device=torch.device("cuda", rank)) |
| 83 | +hdl = symm_mem.rendezvous(t, group) |
| 84 | +``` |
| 85 | + |
| 86 | +The `symm_mem.empty` function creates a tensor that is backed by a symmetric |
| 87 | +memory allocation. The `rendezvous` function establishes a rendezvous with peers |
| 88 | +in the group, and returns a handle to the symmetric memory allocation. The |
| 89 | +handle provides method to access information related to the symmetric memory |
| 90 | +allocation, such as pointers to symmetric buffer on peer ranks, multicast |
| 91 | +pointer (if supported), and signal pads. |
| 92 | + |
| 93 | +The `empty` and `rendezvous` functions must be called in the same order on all |
| 94 | +ranks in the group. |
| 95 | + |
| 96 | +Then, collectives can be called on these tensors. For example, to perform a |
| 97 | +one-shot all-reduce: |
| 98 | + |
| 99 | +```python |
| 100 | +# Most SymmMem ops are under the torch.ops.symm_mem namespace |
| 101 | +torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group) |
| 102 | +``` |
| 103 | + |
| 104 | +Please note that `torch.ops.symm_mem` is an "op namespace" instead of a python |
| 105 | +module. Therefore, you can't import it by `import torch.ops.symm_mem`, neither |
| 106 | +can you import an op by `from torch.ops.symm_mem import one_shot_all_reduce`. |
| 107 | +You can call the op directly as in the example above. |
| 108 | + |
| 109 | +## Write your own kernel |
| 110 | + |
| 111 | +To write your own kernel doing communications with symmetric memory, you’ll need |
| 112 | +access to the addresses of mapped peer buffers and access to signal pads that |
| 113 | +are required for synchronization. In the kernel you’ll also need to perform |
| 114 | +correct synchronizations to make sure that peers are ready for communication, |
| 115 | +and signal to them that this GPU is ready. |
| 116 | + |
| 117 | +PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization |
| 118 | +primitives that operate on the signal pad accompanying each symmetric memory |
| 119 | +allocation. Kernels using symmetric memory can be written both in CUDA and in |
| 120 | +Triton. Here’s an example allocating symmetric tensor and exchanging handles: |
| 121 | + |
| 122 | +```python |
| 123 | +import torch.distributed._symmetric_memory as symm_mem |
| 124 | + |
| 125 | +dist.init_process_group() |
| 126 | +rank = dist.get_rank() |
| 127 | + |
| 128 | +# Allocate a tensor |
| 129 | +t = symm_mem.empty(4096, device=f"cuda:{rank}") |
| 130 | +# Establish symmetric memory and obtain the handle |
| 131 | +hdl = symm_mem.rendezvous(t, dist.group.WORLD) |
| 132 | +``` |
| 133 | + |
| 134 | +Access to buffer pointers, multimem pointer, and signal pads is provided via: |
| 135 | + |
| 136 | +```python |
| 137 | +hdl.buffer_ptrs |
| 138 | +hdl.multicast_ptr |
| 139 | +hdl.signal_pad_ptrs |
| 140 | +``` |
| 141 | + |
| 142 | +Data pointed to by `buffer_ptrs` can be accessed just like regular local data, |
| 143 | +and any necessary compute can also be performed in the usual ways. As with local |
| 144 | +data, you can and should use vectorized accesses to improve efficiency. |
| 145 | + |
| 146 | +Symmetric memory is especially convenient for writing kernels in Triton. While |
| 147 | +previously Triton removed the barriers to writing efficient CUDA code, now |
| 148 | +communications can be added easily to Triton kernels. The kernel below |
| 149 | +demonstrates a low-latency, all-reduce kernel written in Triton. |
| 150 | + |
| 151 | +```python |
| 152 | +@triton.jit |
| 153 | +def one_shot_all_reduce_kernel( |
| 154 | + buf_tuple, |
| 155 | + signal_pad_ptrs, |
| 156 | + output_ptr, |
| 157 | + numel: tl.constexpr, |
| 158 | + rank: tl.constexpr, |
| 159 | + world_size: tl.constexpr, |
| 160 | + BLOCK_SIZE: tl.constexpr, |
| 161 | +): |
| 162 | + ptx_utils.symm_mem_sync( |
| 163 | + signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True |
| 164 | + ) |
| 165 | + |
| 166 | + pid = tl.program_id(axis=0) |
| 167 | + block_start = pid * BLOCK_SIZE |
| 168 | + |
| 169 | + while block_start < numel: |
| 170 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 171 | + mask = offsets < numel |
| 172 | + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16) |
| 173 | + |
| 174 | + for i in tl.static_range(world_size): |
| 175 | + buffer_rank = buf_tuple[i] |
| 176 | + x = tl.load(buffer_rank + offsets, mask=mask) |
| 177 | + acc += x |
| 178 | + |
| 179 | + tl.store(output_ptr + offsets, acc, mask=mask) |
| 180 | + block_start += tl.num_programs(axis=0) * BLOCK_SIZE |
| 181 | + |
| 182 | + ptx_utils.symm_mem_sync( |
| 183 | + signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True |
| 184 | + ) |
| 185 | +``` |
| 186 | + |
| 187 | +Synchronizations at the beginning and the end of the kernel above guarantee that |
| 188 | +all the processes see consistent data. The bulk of the kernel is recognizable |
| 189 | +Triton code, and Triton will optimize it behind the scene, making sure memory |
| 190 | +accesses are performed in an efficient way with vectorization and unrolling. As |
| 191 | +with all Triton kernels, it is easily modifiable to add extra computations or |
| 192 | +change the communication algorithm. Visit |
| 193 | +https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional |
| 194 | +utilities and examples of using symmetric memory to implement common patterns in |
| 195 | +Triton. |
| 196 | + |
| 197 | +## Scale out |
| 198 | + |
| 199 | +Large language models distribute experts onto more than 8 GPUs, hence requiring |
| 200 | +multi-node access capability. NICs capable of RDMA come to help. In addition, |
| 201 | +software libraries such as NVSHMEM or rocSHMEM abstract away the programming |
| 202 | +difference between intra-node access and inter-node access with primitives that |
| 203 | +are slightly higher level than pointer access, such as put and get. |
| 204 | + |
| 205 | +PyTorch provides NVSHMEM plugins to augment Triton kernels’ cross-node |
| 206 | +capabilities. As shown in the code snippet below, one can initiate a cross-node |
| 207 | +put command within the kernel. |
| 208 | + |
| 209 | +```python |
| 210 | +import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem |
| 211 | +from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem |
| 212 | + |
| 213 | +@requires_nvshmem |
| 214 | +@triton.jit |
| 215 | +def my_put_kernel( |
| 216 | + dest, |
| 217 | + src, |
| 218 | + nelems, |
| 219 | + pe, |
| 220 | +): |
| 221 | + nvshmem.put(dest, src, nelems, pe) |
| 222 | +``` |
| 223 | + |
| 224 | +The `requires_nvshmem` decorator is used to indicate that the kernel requires |
| 225 | +the NVSHMEM device library as an external dependency. When Triton compiles the |
| 226 | +kernel, the decorator will search your system paths for the NVSHMEM device |
| 227 | +library. If it is available, Triton will include the necessary device assembly |
| 228 | +to use the NVSHMEM functions. |
| 229 | + |
| 230 | +## API Reference |
| 231 | + |
| 232 | +```{eval-rst} |
| 233 | +.. currentmodule:: torch.distributed._symmetric_memory |
| 234 | +``` |
| 235 | + |
| 236 | +```{eval-rst} |
| 237 | +.. autofunction:: empty |
| 238 | +``` |
| 239 | + |
| 240 | +```{eval-rst} |
| 241 | +.. autofunction:: rendezvous |
| 242 | +``` |
| 243 | + |
| 244 | +```{eval-rst} |
| 245 | +.. autofunction:: is_nvshmem_available |
| 246 | +``` |
| 247 | + |
| 248 | +```{eval-rst} |
| 249 | +.. autofunction:: set_backend |
| 250 | +``` |
| 251 | + |
| 252 | +```{eval-rst} |
| 253 | +.. autofunction:: get_backend |
| 254 | +``` |
| 255 | + |
| 256 | +## Op Reference |
| 257 | +:::{note} |
| 258 | +The following ops are hosted in the `torch.ops.symm_mem` namespace. You can call |
| 259 | +them directly via `torch.ops.symm_mem.<op_name>`. |
| 260 | +::: |
| 261 | + |
| 262 | +```{eval-rst} |
| 263 | +.. currentmodule:: torch.ops.symm_mem |
| 264 | +``` |
| 265 | + |
| 266 | +```{eval-rst} |
| 267 | +.. py:function:: multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor |
| 268 | +
|
| 269 | + Performs a multimem all-reduce operation on the input tensor. This operation |
| 270 | + requires hardware support for multimem operations. On NVIDIA GPUs, NVLink |
| 271 | + SHARP is required. |
| 272 | +
|
| 273 | + :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric. |
| 274 | + :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported. |
| 275 | + :param str group_name: Name of the group to perform all-reduce on. |
| 276 | +
|
| 277 | +
|
| 278 | +.. py:function:: multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) -> Tensor |
| 279 | +
|
| 280 | + Performs a multimem all-gather operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required. |
| 281 | +
|
| 282 | + :param Tensor input: Input tensor to perform all-gather on. |
| 283 | + :param str group_name: Name of the group to perform all-gather on. |
| 284 | + :param Tensor out: Output tensor to store the result of the all-gather operation. Must be symmetric. |
| 285 | +
|
| 286 | +
|
| 287 | +.. py:function:: one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) -> Tensor |
| 288 | +
|
| 289 | + Performs a one-shot all-reduce operation on the input tensor. |
| 290 | +
|
| 291 | + :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric. |
| 292 | + :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported. |
| 293 | + :param str group_name: Name of the group to perform all-reduce on. |
| 294 | +
|
| 295 | +
|
| 296 | +.. py:function:: one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) -> Tensor |
| 297 | +
|
| 298 | + Performs a one-shot all-reduce operation based on the input tensor and writes the result to the output tensor. |
| 299 | +
|
| 300 | + :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric. |
| 301 | + :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported. |
| 302 | + :param str group_name: Name of the group to perform all-reduce on. |
| 303 | + :param Tensor out: Output tensor to store the result of the all-reduce operation. Can be a regular tensor. |
| 304 | +
|
| 305 | +
|
| 306 | +.. py:function:: two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor |
| 307 | +
|
| 308 | + Performs a two-shot all-reduce operation on the input tensor. |
| 309 | +
|
| 310 | + :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric. |
| 311 | + :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported. |
| 312 | + :param str group_name: Name of the group to perform all-reduce on. |
| 313 | +
|
| 314 | +
|
| 315 | +.. py:function:: all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) -> None |
| 316 | +
|
| 317 | + Performs an all-to-all-v operation using NVSHMEM, with split information provided on device. |
| 318 | +
|
| 319 | + :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric. |
| 320 | + :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric. |
| 321 | + :param Tensor in_splits: Tensor containing splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). The splits are in the unit of elements in the 1st dimension. |
| 322 | + :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets. |
| 323 | + :param str group_name: Name of the group to perform all-to-all on. |
| 324 | +
|
| 325 | +
|
| 326 | +.. py:function:: all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str, [major_align: int = None]) -> None |
| 327 | +
|
| 328 | + Perform a 2D all-to-all-v operation using NVSHMEM, with split information provided on device. In Mixture of Experts models, this operation can be used to dispatch tokens. |
| 329 | +
|
| 330 | + :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric. |
| 331 | + :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric. |
| 332 | + :param Tensor in_splits: Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. The splits are in the unit of elements in the 1st dimension. |
| 333 | + :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets. |
| 334 | + :param str group_name: Name of the group to perform all-to-all on. |
| 335 | + :param int major_align: Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets. |
| 336 | +
|
| 337 | + A 2D AllToAllv shuffle is illustrated below: |
| 338 | + (world_size = 2, ne = 2, total number of experts = 4):: |
| 339 | +
|
| 340 | + Source: | Rank 0 | Rank 1 | |
| 341 | + | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 | |
| 342 | +
|
| 343 | + Dest : | Rank 0 | Rank 1 | |
| 344 | + | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 | |
| 345 | +
|
| 346 | + where each `c_i` / `d_i` are slices of the `input` tensor, targeting expert |
| 347 | + `i`, with length indicated by input splits. That is, the 2D AllToAllv |
| 348 | + shuffle achieves a transpose from rank-major order at input to expert-major |
| 349 | + order at output. |
| 350 | +
|
| 351 | + If `major_align` is not 1, the output offsets of c1, c2, c3 will be |
| 352 | + up-aligned to this value. For example, if c0 has length 5 and d0 has |
| 353 | + length 7 (making a total of 12), and if the `major_align` is set to 16, |
| 354 | + the output offset of c1 will be 16. Similar for c2 and c3. This value has |
| 355 | + no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3. |
| 356 | + Note: since cutlass does not support empty bins, we set the aligned length |
| 357 | + to `major_align` if it is 0. See |
| 358 | + https://github.com/pytorch/pytorch/issues/152668. |
| 359 | +
|
| 360 | +
|
| 361 | +.. py:function:: all_to_all_vdev_2d_offset(Tensor input, Tensor out, Tensor in_splits_offsets, Tensor out_splits_offsets, str group_name) -> None |
| 362 | +
|
| 363 | + Perform a 2D AllToAllv shuffle operation, with input split and offset |
| 364 | + information provided on device. The input offsets are not required to be |
| 365 | + exact prefix sum of the input splits, i.e. paddings are allowed between the |
| 366 | + split chunks. The paddings, however, will not be transferred to peer |
| 367 | + ranks. |
| 368 | +
|
| 369 | + In Mixture of Experts models, this operation can be used to combine tokens |
| 370 | + processed by experts on parallel ranks. This operation can be viewed as an |
| 371 | + "reverse" operation to the `all_to_all_vdev_2d` operation (which shuffles |
| 372 | + tokens to experts). |
| 373 | +
|
| 374 | + :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric. |
| 375 | + :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric. |
| 376 | + :param Tensor in_splits_offsets: Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where `ne` is the number of experts. The rows are (in order): input splits and input offsets. The splits are in the unit of elements in the 1st dimension. |
| 377 | + :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets. |
| 378 | + :param str group_name: Name of the group to perform all-to-all on. |
| 379 | +
|
| 380 | +``` |
0 commit comments