
Commit b002562

pytorchbot and kwen2501 authored
Add doc for Symmetric Memory (pytorch#167477)
Add doc for Symmetric Memory (pytorch#166148)

Pull Request resolved: pytorch#166148
Approved by: https://github.com/fduwjj
(cherry picked from commit 1e2e7cb)
Co-authored-by: Ke Wen <[email protected]>
1 parent 5811a8d commit b002562

File tree

3 files changed: +386 −7 lines changed

docs/source/pytorch-api.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -41,6 +41,7 @@ torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
 torch.distributed.tensor.parallel <distributed.tensor.parallel>
 torch.distributed.optim <distributed.optim>
 torch.distributed.pipelining <distributed.pipelining>
+torch.distributed._symmetric_memory <symmetric_memory>
 torch.distributed.checkpoint <distributed.checkpoint>
 torch.distributions <distributions>
 torch.compiler <torch.compiler>
```

docs/source/symmetric_memory.md

Lines changed: 380 additions & 0 deletions
@@ -0,0 +1,380 @@
```{eval-rst}
.. role:: hidden
    :class: hidden-section
```

# PyTorch Symmetric Memory

:::{note}
`torch.distributed._symmetric_memory` is currently in alpha state and under
development. APIs may change.
:::

## Why Symmetric Memory?

With rapidly evolving parallelization techniques, existing frameworks and
libraries often struggle to keep up, and developers increasingly rely on custom
implementations that schedule communication and computation directly. In recent
years we’ve witnessed a shift from primarily one-dimensional data parallelism to
multi-dimensional parallelism. The latter has different latency requirements for
different types of communication and thus requires fine-grained overlapping of
compute and communication.

To minimize interference with compute, these techniques also rely on copy
engines and network interface cards (NICs) to drive communication. Network
transport protocols such as remote direct memory access (RDMA) further enhance
performance by enabling direct, high-speed, low-latency communication between
processors and memory. This growing variety indicates the need for finer-grained
communication primitives than today’s high-level collective APIs offer:
primitives that let developers implement algorithms tailored to their use cases,
such as low-latency collectives, fine-grained compute-communication overlap, or
custom fusions.

Furthermore, today’s advanced AI systems connect GPUs with high-bandwidth links
(such as NVLink, InfiniBand, or RoCE), making GPU global memory directly
accessible to peers. Such connections present a great opportunity to program the
system as a single, gigantic GPU with vast accessible memory, instead of
programming isolated “GPU islands.”

In this document, we show how you can use PyTorch Symmetric Memory to program
modern GPU systems as a “single GPU” and achieve fine-grained remote access.

## What does PyTorch Symmetric Memory unlock?

PyTorch Symmetric Memory unlocks three new capabilities:

- **Customized communication patterns**: Increased flexibility in kernel writing
  allows developers to write custom kernels that implement their own computation
  and communication, tailored directly to the needs of the application. It also
  becomes straightforward to add support for new data types, along with any
  special compute those data types require, even if it is not yet present in the
  standard libraries.

- **In-kernel compute-comm fusion**: Device-initiated communication lets
  developers write kernels containing both computation and communication
  instructions, allowing compute and data movement to be fused at the smallest
  possible granularity.

- **Low-latency remote access**: Network transport protocols like RDMA enhance
  the performance of symmetric memory in networked environments by enabling
  direct, high-speed, low-latency communication between processors and memory.
  RDMA eliminates the overhead of the traditional network stack and CPU
  involvement. It also offloads data transfer from the compute units to the
  NICs, freeing up compute resources for computation.

Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new
applications with the above capabilities.

## A “Hello World” example

The PyTorch SymmMem programming model involves two key elements:

- creating symmetric tensors
- creating SymmMem kernels

To create symmetric tensors, one can use the
`torch.distributed._symmetric_memory` package:

```python
import torch
import torch.distributed._symmetric_memory as symm_mem

# `rank` is this process's rank and `group` is the process group
# (see the end-to-end sketch further below)
t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)
```

The `symm_mem.empty` function creates a tensor that is backed by a symmetric
memory allocation. The `rendezvous` function establishes a rendezvous with peers
in the group and returns a handle to the symmetric memory allocation. The handle
provides methods to access information about the allocation, such as pointers to
the symmetric buffers on peer ranks, the multicast pointer (if supported), and
the signal pads.

The `empty` and `rendezvous` functions must be called in the same order on all
ranks in the group.

Then, collectives can be called on these tensors. For example, to perform a
one-shot all-reduce:

```python
# Most SymmMem ops are under the torch.ops.symm_mem namespace
torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group)
```

Please note that `torch.ops.symm_mem` is an "op namespace" rather than a Python
module. Therefore, you can't import it with `import torch.ops.symm_mem`, nor can
you import an op with `from torch.ops.symm_mem import one_shot_all_reduce`. Call
the op directly, as in the example above.
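
Putting these pieces together, a complete single-node run could look like the
sketch below. This is a minimal illustration rather than an official recipe: it
assumes the script is launched with `torchrun` (so `init_process_group` can
discover rank and world size), and that the op accepts the group's name string
(obtained here via `group.group_name`).

```python
# Minimal SymmMem "hello world" sketch (assumes launch via
# `torchrun --nproc-per-node=<num_gpus> hello_symm_mem.py`).
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()
torch.cuda.set_device(rank)
group = dist.group.WORLD

# Symmetric allocation + rendezvous (same order on all ranks)
t = symm_mem.empty(128, dtype=torch.bfloat16, device=torch.device("cuda", rank))
symm_mem.rendezvous(t, group)

t.fill_(rank)  # each rank contributes its own rank id

# One-shot all-reduce over the symmetric tensor; the result is a regular tensor
out = torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group.group_name)
print(f"rank {rank}: {out[0].item()}")  # expected: sum of all rank ids

dist.destroy_process_group()
```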

## Write your own kernel

To write your own kernel that communicates through symmetric memory, you need
access to the addresses of the mapped peer buffers and to the signal pads
required for synchronization. Inside the kernel you also need to perform the
proper synchronization: make sure peers are ready for communication, and signal
to them that this GPU is ready.

PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization
primitives that operate on the signal pad accompanying each symmetric memory
allocation. Kernels using symmetric memory can be written both in CUDA and in
Triton. Here’s an example that allocates a symmetric tensor and exchanges
handles:

```python
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()

# Allocate a tensor
t = symm_mem.empty(4096, device=f"cuda:{rank}")
# Establish symmetric memory and obtain the handle
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
```

Access to the buffer pointers, the multicast pointer, and the signal pads is
provided via:

```python
hdl.buffer_ptrs
hdl.multicast_ptr
hdl.signal_pad_ptrs
```

Data pointed to by `buffer_ptrs` can be accessed just like regular local data,
and any necessary compute can also be performed in the usual ways. As with local
data, you can and should use vectorized accesses to improve efficiency.
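
For example, from the host side the handle can also hand back a peer rank's
buffer as an ordinary tensor view, which you can then read or write like any
local tensor. The `get_buffer` accessor used below is not part of the snippet
above; treat its exact name and signature as an assumption for illustration,
and note that you remain responsible for appropriate synchronization.

```python
# Illustrative sketch: read a peer's symmetric buffer directly from the host.
# `t`, `hdl`, and `rank` come from the snippet above; `get_buffer` is assumed
# to take (peer_rank, sizes, dtype) and return a tensor view of that buffer.
peer = (rank + 1) % dist.get_world_size()
peer_buf = hdl.get_buffer(peer, t.shape, t.dtype)
print(peer_buf.sum())  # reads remote memory over NVLink, like local data
```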

Symmetric memory is especially convenient for writing kernels in Triton. Triton
already removed much of the barrier to writing efficient CUDA code; with
symmetric memory, communication can now be added to Triton kernels just as
easily. The kernel below demonstrates a low-latency all-reduce kernel written in
Triton.

```python
import triton
import triton.language as tl


@triton.jit
def one_shot_all_reduce_kernel(
    buf_tuple,        # per-rank symmetric buffers, one entry per rank
    signal_pad_ptrs,  # signal pads used for cross-rank synchronization
    output_ptr,
    numel: tl.constexpr,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # `ptx_utils.symm_mem_sync` is a signal-pad barrier helper; see the kraken
    # utilities linked below for an implementation.
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
    )

    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    while block_start < numel:
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < numel
        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)

        # Load the same block from every rank's buffer and accumulate
        for i in tl.static_range(world_size):
            buffer_rank = buf_tuple[i]
            x = tl.load(buffer_rank + offsets, mask=mask)
            acc += x

        tl.store(output_ptr + offsets, acc, mask=mask)
        block_start += tl.num_programs(axis=0) * BLOCK_SIZE

    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
    )
```

The synchronizations at the beginning and the end of the kernel above guarantee
that all processes see consistent data. The bulk of the kernel is recognizable
Triton code, and Triton will optimize it behind the scenes, making sure memory
accesses are performed efficiently with vectorization and unrolling. As with all
Triton kernels, it is easy to modify to add extra computation or change the
communication algorithm. Visit
https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional
utilities and examples of using symmetric memory to implement common patterns in
Triton.
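
For completeness, a host-side launch of the kernel above might look roughly like
the sketch below. How the per-rank buffers and signal pads are packed into
kernel arguments varies across PyTorch and Triton versions, so treat the
`get_buffer` accessor and the direct use of `hdl.signal_pad_ptrs` as
illustrative assumptions; the kraken repository linked above contains
maintained, end-to-end versions of this pattern.

```python
# Hypothetical launch of one_shot_all_reduce_kernel (argument plumbing is a
# sketch; exact handling of buffer/signal-pad pointers may differ by version).
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton

rank = dist.get_rank()
world_size = dist.get_world_size()

t = symm_mem.empty(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
t.fill_(rank)

# One view per peer buffer, in rank order (assumed accessor)
buf_tuple = tuple(hdl.get_buffer(r, t.shape, t.dtype) for r in range(world_size))
output = torch.empty_like(t)

BLOCK_SIZE = 1024
grid = (triton.cdiv(t.numel(), BLOCK_SIZE),)
one_shot_all_reduce_kernel[grid](
    buf_tuple,
    hdl.signal_pad_ptrs,  # signal pads consumed by symm_mem_sync
    output,
    numel=t.numel(),
    rank=rank,
    world_size=world_size,
    BLOCK_SIZE=BLOCK_SIZE,
)
```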

## Scale out

Large language models distribute experts across more than 8 GPUs, which requires
multi-node access capability. RDMA-capable NICs provide it. In addition,
software libraries such as NVSHMEM or rocSHMEM abstract away the programming
difference between intra-node and inter-node access with primitives that are
slightly higher level than pointer access, such as put and get.

PyTorch provides NVSHMEM plugins to augment Triton kernels’ cross-node
capabilities. As shown in the code snippet below, one can initiate a cross-node
put command within the kernel.

```python
import triton

import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem


@requires_nvshmem
@triton.jit
def my_put_kernel(
    dest,
    src,
    nelems,
    pe,
):
    # Device-initiated put: copy `nelems` elements from local `src` into the
    # symmetric `dest` buffer on peer `pe`
    nvshmem.put(dest, src, nelems, pe)
```

The `requires_nvshmem` decorator indicates that the kernel requires the NVSHMEM
device library as an external dependency. When Triton compiles the kernel, the
decorator searches your system paths for the NVSHMEM device library; if it is
available, Triton includes the necessary device assembly to use the NVSHMEM
functions.
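
A host-side sketch of how such a kernel might be driven is shown below. The
backend name passed to `set_backend` and the way tensor arguments map to NVSHMEM
symmetric addresses are assumptions for illustration; consult the NVSHMEM plugin
tests in the PyTorch repository for authoritative usage.

```python
# Hedged sketch: each rank puts its `src` buffer into its neighbor's `dst`.
# Backend name and argument mapping are assumptions; details may differ.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()
torch.cuda.set_device(rank)

symm_mem.set_backend("NVSHMEM")  # back allocations with NVSHMEM (assumed name)

src = symm_mem.empty(1024, dtype=torch.int32, device=f"cuda:{rank}")
dst = symm_mem.empty(1024, dtype=torch.int32, device=f"cuda:{rank}")
symm_mem.rendezvous(src, dist.group.WORLD)
symm_mem.rendezvous(dst, dist.group.WORLD)

src.fill_(rank)
peer = (rank + 1) % dist.get_world_size()

# Launch a single program that issues the device-initiated put
my_put_kernel[(1,)](dst, src, nelems=dst.numel(), pe=peer)
```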

## API Reference

```{eval-rst}
.. currentmodule:: torch.distributed._symmetric_memory
```

```{eval-rst}
.. autofunction:: empty
```

```{eval-rst}
.. autofunction:: rendezvous
```

```{eval-rst}
.. autofunction:: is_nvshmem_available
```

```{eval-rst}
.. autofunction:: set_backend
```

```{eval-rst}
.. autofunction:: get_backend
```
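
For example, one might check NVSHMEM availability and select the allocation
backend before creating any symmetric tensors. The backend name string and the
device argument to `get_backend` below are illustrative assumptions; see the
function references above for the exact accepted values.

```python
import torch
import torch.distributed._symmetric_memory as symm_mem

# Prefer NVSHMEM-backed allocations when available (backend name assumed)
if symm_mem.is_nvshmem_available():
    symm_mem.set_backend("NVSHMEM")

# Query which backend serves allocations on this device (signature assumed)
print(symm_mem.get_backend(torch.device("cuda")))
```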

## Op Reference

:::{note}
The following ops are hosted in the `torch.ops.symm_mem` namespace. You can call
them directly via `torch.ops.symm_mem.<op_name>`.
:::

```{eval-rst}
.. currentmodule:: torch.ops.symm_mem
```

```{eval-rst}
.. py:function:: multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor

   Performs a multimem all-reduce operation on the input tensor. This operation
   requires hardware support for multimem operations. On NVIDIA GPUs, NVLink
   SHARP is required.

   :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
   :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
   :param str group_name: Name of the group to perform all-reduce on.


.. py:function:: multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) -> Tensor

   Performs a multimem all-gather operation on the input tensor. This operation
   requires hardware support for multimem operations. On NVIDIA GPUs, NVLink
   SHARP is required.

   :param Tensor input: Input tensor to perform all-gather on.
   :param str group_name: Name of the group to perform all-gather on.
   :param Tensor out: Output tensor to store the result of the all-gather operation. Must be symmetric.


.. py:function:: one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) -> Tensor

   Performs a one-shot all-reduce operation on the input tensor.

   :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
   :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
   :param str group_name: Name of the group to perform all-reduce on.


.. py:function:: one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) -> Tensor

   Performs a one-shot all-reduce operation based on the input tensor and writes
   the result to the output tensor.

   :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
   :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
   :param str group_name: Name of the group to perform all-reduce on.
   :param Tensor out: Output tensor to store the result of the all-reduce operation. Can be a regular tensor.


.. py:function:: two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor

   Performs a two-shot all-reduce operation on the input tensor.

   :param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
   :param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
   :param str group_name: Name of the group to perform all-reduce on.


.. py:function:: all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) -> None

   Performs an all-to-all-v operation using NVSHMEM, with split information
   provided on device.

   :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
   :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
   :param Tensor in_splits: Tensor containing the splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). The splits are in the unit of elements in the 1st dimension.
   :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets.
   :param str group_name: Name of the group to perform all-to-all on.


.. py:function:: all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str, [major_align: int = None]) -> None

   Performs a 2D all-to-all-v operation using NVSHMEM, with split information
   provided on device. In Mixture of Experts models, this operation can be used
   to dispatch tokens.

   :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
   :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
   :param Tensor in_splits: Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. The splits are in the unit of elements in the 1st dimension.
   :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
   :param str group_name: Name of the group to perform all-to-all on.
   :param int major_align: Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets.

   A 2D AllToAllv shuffle is illustrated below
   (world_size = 2, ne = 2, total number of experts = 4)::

        Source: | Rank 0 | Rank 1 |
                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

        Dest  : | Rank 0 | Rank 1 |
                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |

   where each `c_i` / `d_i` is a slice of the `input` tensor targeting expert
   `i`, with length indicated by the input splits. That is, the 2D AllToAllv
   shuffle achieves a transpose from rank-major order at the input to
   expert-major order at the output.

   If `major_align` is not 1, the output offsets of c1, c2, c3 will be
   up-aligned to this value. For example, if c0 has length 5 and d0 has
   length 7 (making a total of 12), and if `major_align` is set to 16,
   the output offset of c1 will be 16. Similarly for c2 and c3. This value has
   no effect on the offsets of the minor dimension, i.e. d0, d1, d2 and d3.
   Note: since CUTLASS does not support empty bins, we set the aligned length
   to `major_align` if it is 0. See
   https://github.com/pytorch/pytorch/issues/152668.


.. py:function:: all_to_all_vdev_2d_offset(input: Tensor, out: Tensor, in_splits_offsets: Tensor, out_splits_offsets: Tensor, group_name: str) -> None

   Performs a 2D AllToAllv shuffle operation, with input split and offset
   information provided on device. The input offsets are not required to be
   exact prefix sums of the input splits, i.e. paddings are allowed between the
   split chunks. The paddings, however, will not be transferred to peer ranks.

   In Mixture of Experts models, this operation can be used to combine tokens
   processed by experts on parallel ranks. This operation can be viewed as a
   "reverse" operation to the `all_to_all_vdev_2d` operation (which shuffles
   tokens to experts).

   :param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
   :param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
   :param Tensor in_splits_offsets: Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where `ne` is the number of experts per rank. The rows are (in order): input splits and input offsets. The splits are in the unit of elements in the 1st dimension.
   :param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
   :param str group_name: Name of the group to perform all-to-all on.
```
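
To make the split/offset conventions concrete, the sketch below sets up a
hypothetical `all_to_all_vdev` call on 2 ranks. Shapes, dtypes, and values are
invented for illustration, and the int64 dtype for the split tensors is an
assumption; as stated in the op docs above, all buffers involved must be
symmetric, and the op relies on NVSHMEM.

```python
# Hedged sketch of all_to_all_vdev on world_size = 2 (run on every rank).
# dtypes/shapes are illustrative; all buffers below are symmetric allocations.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = dist.get_rank()
world_size = dist.get_world_size()
group = dist.group.WORLD

max_elems = 1024
inp = symm_mem.empty(max_elems, dtype=torch.float32, device=f"cuda:{rank}")
out = symm_mem.empty(max_elems, dtype=torch.float32, device=f"cuda:{rank}")
# in_splits[i]: number of elements (along dim 0) this rank sends to rank i
in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=f"cuda:{rank}")
# Row 0: received splits, row 1: received offsets (filled in by the op)
out_splits_offsets = symm_mem.empty(
    (2, world_size), dtype=torch.int64, device=f"cuda:{rank}"
)
for buf in (inp, out, in_splits, out_splits_offsets):
    symm_mem.rendezvous(buf, group)

in_splits.copy_(torch.tensor([3, 5], device=inp.device))  # 3 elems to rank 0, 5 to rank 1

torch.ops.symm_mem.all_to_all_vdev(
    inp, out, in_splits, out_splits_offsets, group.group_name
)
```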
