
Commit 30959b0

[Example] One shot all reduce
stack-info: PR: #245, branch: joydddd/stack/12
1 parent 4393413 commit 30959b0

File tree

4 files changed: +271, -8 lines changed


examples/README.rst

Lines changed: 2 additions & 0 deletions

@@ -48,6 +48,7 @@ Other Operations
 - ``cross_entropy.py``: Cross entropy loss function
 - ``embedding.py``: Embedding lookup operation
 - ``all_gather_matmul.py``: All-gather operation followed by matrix multiplication
+- ``all_reduce.py``: All-reduce operation (one-shot)

 .. toctree::
    :maxdepth: 2
@@ -56,6 +57,7 @@ Other Operations

    add
    all_gather_matmul
+   all_reduce
    attention
    bmm
    concatenate

examples/all_reduce.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
"""
One-Shot All-Reduce Example
========================================
This example demonstrates how to implement a one-shot (pull-based) all-reduce operation
using Helion and PyTorch's distributed capabilities. It includes a Helion kernel that
shows how to perform cross-device synchronization using symmetric memory signal pads
and how to access symmetric memory tensors resident on peer devices.
"""

# %%
# Imports
# -------
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.utils.cpp_extension import load_inline

import helion
import helion.language as hl

# %%
# TODO(joydddd): workaround until symm mem natively supports extracting dev_ptrs as tensors: from_blob
from_blob_cpp = """
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>


at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) {

    at::Tensor tensor = at::for_blob((void*)data_ptr, sizes)
                            .deleter([](void *ptr) {
                                ;
                            })
                            .options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type))
                            .make_tensor();

    return tensor;
}
"""

cpp_mod = load_inline(
    "cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"]
)


def dev_array_to_tensor_short(
    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Convert a device array pointer to a PyTorch tensor.

    This is a workaround function that creates a PyTorch tensor from a raw device pointer
    using the C++ extension. It's used to interface with symmetric memory device pointers
    before native support is available.

    Args:
        dev_array_ptr: Raw device pointer as integer
        shape: Shape of the tensor to create
        dtype: PyTorch data type for the tensor
        device: Target device for the tensor

    Returns:
        PyTorch tensor created from the device pointer
    """
    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)  # pyright: ignore[reportAttributeAccessIssue]


@helion.jit(
    config=helion.Config(
        block_sizes=[8192],
        num_warps=32,
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared_tuple: tuple[torch.Tensor, ...],
    my_rank: hl.constexpr,
) -> torch.Tensor:
    """
    Helion JIT-compiled kernel for the one-shot all-reduce operation.

    This kernel implements a distributed all-reduce using symmetric memory and signal pads
    for cross-device synchronization. It performs an element-wise summation across all
    devices in the distributed group, using tiled computation for memory efficiency.

    Args:
        signal_pad_addrs: Tensor containing addresses of signal pads for all devices
        local_signal_pad: Local signal pad for synchronization
        a_shared_tuple: Tuple of shared tensors from all devices in the group
        my_rank: Current device's rank in the distributed group

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    _, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)

    for tile_n in hl.tile(N):
        ptr_tile = signal_pad_addrs[:]
        stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile)
        hl.signal(
            stack_signalpad,
            [tile_n.id, my_rank],
            signal=1,
            wait_for=0,
            scope="sys",
            hasPreviousMemAccess=False,
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
            )

        acc = hl.zeros(
            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
        )

        for a in a_shared_tuple:
            acc += a[tile_n]

        out[tile_n] = acc

        hl.signal(
            stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
                hasSubsequentMemAccess=False,
            )
    return out


def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    """
    Prepare the symmetric memory tensors for the Helion one-shot all-reduce kernel.
    Tracks the shared tensors as a tuple of tensors and/or dev_ptrs tensors.

    Args:
        a_shared: Input tensor to be all-reduced across all devices

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    assert dist.group.WORLD is not None

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

    a_shared_tuple = tuple(
        [
            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
            for i in range(symm_mem_hdl.world_size)
        ]
    )

    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, symm_mem_hdl.world_size)

    signal_pad_addrs = dev_array_to_tensor_short(
        symm_mem_hdl.signal_pad_ptrs_dev,
        (symm_mem_hdl.world_size,),
        dtype=torch.uint64,
        device=a_shared.device,
    )

    return one_shot_all_reduce_kernel(
        signal_pad_addrs,
        local_signal_pad,
        a_shared_tuple,
        my_rank=symm_mem_hdl.rank,
    )


def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
    """
    Test the Helion all-reduce implementation against PyTorch's reference implementation.

    Args:
        N: Total number of elements to test (will be divided by world_size per device)
        device: CUDA device to run the test on
        dtype: Data type for the test tensors
    """
    dist_group = dist.group.WORLD
    assert dist_group is not None

    world_size = dist.get_world_size()
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    a_shared_clone = symm_mem.empty(
        a_shared.shape,
        dtype=a_shared.dtype,
        device=a_shared.device,
    )
    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
    a_shared_clone.copy_(a_shared)

    a_out = helion_one_shot_all_reduce(a_shared)

    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
        a_shared_clone, "sum", dist_group.group_name
    )

    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
    """
    Main entry point for the all-reduce example.

    Sets up the distributed environment, initializes the CUDA device, runs the
    all-reduce test, and then cleans up.
    """
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(16384, device, torch.bfloat16)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_reduce.py
    """
    main()
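
Note: a minimal caller sketch (illustrative only; it assumes the process group and the
CUDA device have already been initialized as in main() above, and that
helion_one_shot_all_reduce is importable from this file):

    import torch
    import torch.distributed._symmetric_memory as symm_mem

    # Each rank allocates its shard in symmetric memory and fills it with data.
    x = symm_mem.empty(8192, dtype=torch.bfloat16, device="cuda").normal_()
    # Element-wise sum across all ranks; the result has the same shape as x.
    y = helion_one_shot_all_reduce(x)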

helion/language/creation_ops.py

Lines changed: 16 additions & 5 deletions

@@ -19,7 +19,11 @@
 __all__ = ["arange", "full", "zeros"]


-def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
+def zeros(
+    shape: list[object],
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
     """
     Return a device-tensor filled with zeros.

@@ -55,12 +59,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
     - :func:`~helion.language.full`: For filling with arbitrary values
     - :func:`~helion.language.arange`: For creating sequences
     """
-    return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
+    return full(
+        shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
+    )


 @_decorators.api(tiles_as_sizes=True)
 def full(
-    shape: list[object], value: float, dtype: torch.dtype = torch.float32
+    shape: list[object],
+    value: float,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     """
     Create a device-tensor filled with a specified value.
@@ -104,6 +113,7 @@ def _full_fake(
     shape: list[int | torch.SymInt],
     value: float,
     dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     if not isinstance(shape, (list, tuple)):
         raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -112,7 +122,7 @@ def _full_fake(
     return torch.empty(
         [*shape],
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )


@@ -164,6 +174,7 @@ def _(
 def arange(
     *args: int,
     dtype: torch.dtype | None = None,
+    device: torch.device | None = None,
     **kwargs: object,
 ) -> torch.Tensor:
     """
@@ -192,5 +203,5 @@ def arange(
         *args,
         **kwargs,
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )
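
Note: the new optional device argument appears intended to let a kernel place an
allocation on an explicit device instead of the environment default; the all-reduce
kernel above already relies on it when building its accumulator:

    # From one_shot_all_reduce_kernel in examples/all_reduce.py above:
    acc = hl.zeros(
        [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
    )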

test/test_type_propagation.expected

Lines changed: 3 additions & 3 deletions

@@ -495,7 +495,7 @@ def root_graph_0():
     # File: .../basic_kernels.py:40 in hl_full_usage, code: tmp = hl.full(tile, 1, dtype=x.dtype)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32)
+    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32, None)

     # File: .../basic_kernels.py:41 in hl_full_usage, code: tmp += x[tile]
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -553,7 +553,7 @@ def root_graph_0():
     # File: .../basic_kernels.py:29 in hl_zeros_usage, code: tmp = hl.zeros(tile, dtype=x.dtype)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32)
+    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32, None)

     # File: .../basic_kernels.py:30 in hl_zeros_usage, code: tmp += x[tile]
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -679,7 +679,7 @@ def root_graph_1():
     # File: .../matmul.py:52 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32)
+    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32, None)

     # File: .../matmul.py:53 in matmul, code: for tile_k in hl.tile(k):
     _for_loop = helion_language__tracing_ops__for_loop(0, [0], [512], [acc])
