Commit ec93b60
One shot all reduce Example
stack-info: PR: #245, branch: joydddd/stack/12
1 parent dc63692 commit ec93b60

File tree (3 files changed, +202 -8 lines changed):
- examples/all_reduce.py
- helion/language/creation_ops.py
- test/test_type_propagation.expected


examples/all_reduce.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Symmetric Memory Helpers
from torch.utils.cpp_extension import load_inline

import helion
import helion.language as hl

from_blob_cpp = """
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>


at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) {

    at::Tensor tensor = at::for_blob((void*)data_ptr, sizes)
             .deleter([](void *ptr) {
                 ;
             })
             .options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type))
             .make_tensor();

    return tensor;
}
"""

cpp_mod = load_inline(
    "cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"]
)


def dev_array_to_tensor_short(
    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)  # pyright: ignore[reportAttributeAccessIssue]


@helion.jit(
    config=helion.Config(
        block_sizes=[8192],
        num_warps=32,
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared_tuple: tuple[torch.Tensor, ...],
    my_rank: hl.constexpr,
) -> torch.Tensor:
    _, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)

    for tile_n in hl.tile(N):
        ptr_tile = signal_pad_addrs[:]
        stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile)
        hl.signal(
            stack_signalpad,
            [tile_n.id, my_rank],
            signal=1,
            wait_for=0,
            scope="sys",
            hasPreviousMemAccess=False,
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
            )

        acc = hl.zeros(
            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
        )

        for a in a_shared_tuple:
            acc += a[tile_n]

        out[tile_n] = acc

        hl.signal(
            stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
                hasSubsequentMemAccess=False,
            )
    return out


def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    assert dist.group.WORLD is not None

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

    a_shared_tuple = tuple(
        [
            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
            for i in range(symm_mem_hdl.world_size)
        ]
    )

    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, symm_mem_hdl.world_size)

    signal_pad_addrs = dev_array_to_tensor_short(
        symm_mem_hdl.signal_pad_ptrs_dev,
        (symm_mem_hdl.world_size,),
        dtype=torch.uint64,
        device=a_shared.device,
    )

    return one_shot_all_reduce_kernel(
        signal_pad_addrs,
        local_signal_pad,
        a_shared_tuple,
        my_rank=symm_mem_hdl.rank,
    )


def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
    dist_group = dist.group.WORLD
    assert dist_group is not None

    world_size = dist.get_world_size()
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    a_shared_clone = symm_mem.empty(
        a_shared.shape,
        dtype=a_shared.dtype,
        device=a_shared.device,
    )
    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
    a_shared_clone.copy_(a_shared)

    a_out = helion_one_shot_all_reduce(a_shared)

    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
        a_shared_clone, "sum", dist_group.group_name
    )

    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(16384, device, torch.bfloat16)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_reduce.py
    """
    main()
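
For readers who want to sanity-check the example's semantics, below is a minimal reference sketch of the same sum-reduction using a plain torch.distributed collective. The helper name reference_all_reduce is hypothetical and not part of this commit; it assumes dist.init_process_group has already been called, as in main() above, and its output can be compared against helion_one_shot_all_reduce with torch.testing.assert_close.

import torch
import torch.distributed as dist


def reference_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # dist.all_reduce sums element-wise across all ranks in the default group,
    # which matches the "sum" reduction the Helion kernel above produces.
    out = x.clone()  # keep the caller's buffer intact
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out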

helion/language/creation_ops.py

Lines changed: 16 additions & 5 deletions
@@ -19,7 +19,11 @@
 __all__ = ["arange", "full", "zeros"]
 
 
-def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
+def zeros(
+    shape: list[object],
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
     """
     Return a device-tensor filled with zeros.
 
@@ -55,12 +59,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
         - :func:`~helion.language.full`: For filling with arbitrary values
         - :func:`~helion.language.arange`: For creating sequences
     """
-    return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
+    return full(
+        shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
+    )
 
 
 @_decorators.api(tiles_as_sizes=True)
 def full(
-    shape: list[object], value: float, dtype: torch.dtype = torch.float32
+    shape: list[object],
+    value: float,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     """
     Create a device-tensor filled with a specified value.
@@ -104,6 +113,7 @@ def _full_fake(
     shape: list[int | torch.SymInt],
     value: float,
     dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     if not isinstance(shape, (list, tuple)):
         raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -112,7 +122,7 @@ def _full_fake(
     return torch.empty(
         [*shape],
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )
 
 
@@ -164,6 +174,7 @@ def _(
 def arange(
     *args: int,
     dtype: torch.dtype | None = None,
+    device: torch.device | None = None,
     **kwargs: object,
 ) -> torch.Tensor:
     """
@@ -192,5 +203,5 @@ def arange(
         *args,
         **kwargs,
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )
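
For context, the sketch below shows how the new device= keyword can be used from a kernel body, mirroring the acc = hl.zeros([tile_n], ..., device=local_signal_pad.device) allocation in examples/all_reduce.py above. The kernel name copy_with_explicit_device and the config values are placeholders, not part of this commit.

import torch

import helion
import helion.language as hl


@helion.jit(config=helion.Config(block_sizes=[1024], num_warps=4), static_shapes=True)
def copy_with_explicit_device(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Allocate the accumulator on an explicit device instead of relying on
        # the env.device fallback added in _full_fake above.
        acc = hl.zeros([tile], dtype=x.dtype, device=x.device)
        acc += x[tile]
        out[tile] = acc
    return out

When device is omitted, the fake-tensor path falls back to env.device, so existing call sites (as in the updated test expectations below) are unchanged.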

test/test_type_propagation.expected

Lines changed: 3 additions & 3 deletions
@@ -495,7 +495,7 @@ def root_graph_0():
     # File: .../basic_kernels.py:40 in hl_full_usage, code: tmp = hl.full(tile, 1, dtype=x.dtype)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32)
+    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32, None)
 
     # File: .../basic_kernels.py:41 in hl_full_usage, code: tmp += x[tile]
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -553,7 +553,7 @@ def root_graph_0():
     # File: .../basic_kernels.py:29 in hl_zeros_usage, code: tmp = hl.zeros(tile, dtype=x.dtype)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32)
+    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32, None)
 
     # File: .../basic_kernels.py:30 in hl_zeros_usage, code: tmp += x[tile]
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -669,7 +669,7 @@ def root_graph_1():
     # File: .../matmul.py:32 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
-    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32)
+    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32, None)
 
     # File: .../matmul.py:33 in matmul, code: for tile_k in hl.tile(k):
     _for_loop = helion_language__tracing_ops__for_loop(0, [0], [512], [acc])
