[Example] One shot all reduce #245

Merged
merged 1 commit on Aug 8, 2025

2 changes: 2 additions & 0 deletions examples/README.rst
@@ -49,6 +49,7 @@ Other Operations
- :doc:`cross_entropy.py <cross_entropy>`: Cross entropy loss function
- :doc:`embedding.py <embedding>`: Embedding lookup operation
- :doc:`all_gather_matmul.py <all_gather_matmul>`: All-gather operation followed by matrix multiplication
- :doc:`all_reduce.py <all_reduce>`: All-reduce operation (one-shot)

.. toctree::
   :maxdepth: 2
@@ -57,6 +58,7 @@ Other Operations

   add
   all_gather_matmul
   all_reduce
   attention
   bmm
   concatenate
263 changes: 263 additions & 0 deletions examples/all_reduce.py
@@ -0,0 +1,263 @@
"""
One-Shot All-Reduce Example
========================================
This example demonstrates how to implement a one-shot pulling all-reduce operation
using Helion and PyTorch's distributed capabilities. It includes a Helion kernel
demonstrating how to do cross-device synchronization using symmetric memory signal pads
and access symmetric memory tensor resident on peer devices.
"""

# %%
# Imports
# -------
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.utils.cpp_extension import load_inline

import helion
import helion.language as hl

# %%
# Workaround until symmetric memory natively supports extracting dev_ptrs as tensors: from_blob
from_blob_cpp = """
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>


at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) {

at::Tensor tensor = at::for_blob((void*)data_ptr, sizes)
.deleter([](void *ptr) {
;
})
.options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type))
.make_tensor();

return tensor;
}
"""

cpp_mod = load_inline(
"cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"]
)


def dev_array_to_tensor_short(
    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Convert a device array pointer to a PyTorch tensor.

    This is a workaround function that creates a PyTorch tensor from a raw device pointer
    using the C++ extension. It's used to interface with symmetric memory device pointers
    before native support is available.

    Args:
        dev_array_ptr: Raw device pointer as integer
        shape: Shape of the tensor to create
        dtype: PyTorch data type for the tensor
        device: Target device for the tensor

    Returns:
        PyTorch tensor created from the device pointer
    """
    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)  # pyright: ignore[reportAttributeAccessIssue]


# %%
# One Shot All-Reduce Kernel Implementation
# ----------------------------------------
@helion.jit(
    config=helion.Config(
Contributor: Are we able to autotune this yet?

Contributor Author: No. Unfortunately.

Contributor: What are the blockers?

Contributor Author: We'll need support for collaborative autotuning across multiple torchrun-initiated processes.

I have event-based benchmarking infra ready in #393 (autotuner/benchmarker), which reports timing results on process 0.

We need to:

  1. Make sure all processes benchmark the same configs in the same order. (Is there any randomization in the autotuning process?)
  2. Use the event-based benchmarker when in a torchrun environment inside the autotuner. (easy)
  3. Communicate results from process 0 to all processes, OR have process 0 make the decision and communicate the optimal config to all processes (through caching?); see the sketch below.
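
A minimal sketch of step 3, assuming rank 0 has already picked a winning config (the broadcast_best_config helper is hypothetical and not part of this PR or of the Helion API):

import torch.distributed as dist

def broadcast_best_config(best_config: object | None, rank: int) -> object:
    # Rank 0 supplies the chosen config; the other ranks pass a placeholder
    # and receive rank 0's choice in place. Assumes the process group that
    # torchrun initialized is still alive.
    obj_list = [best_config if rank == 0 else None]
    dist.broadcast_object_list(obj_list, src=0)
    return obj_list[0]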

        block_sizes=[8192],
        num_warps=32,
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared_tuple: tuple[torch.Tensor, ...],
    my_rank: hl.constexpr,
) -> torch.Tensor:
    """
    Helion JIT-compiled kernel for one-shot all-reduce operation.

    This kernel implements a distributed all-reduce using symmetric memory and signal pads
    for cross-device synchronization. It performs element-wise summation across all devices
    in the distributed group using tiled computation for memory efficiency.

    Args:
        signal_pad_addrs: Tensor containing addresses of signal pads for all devices
        local_signal_pad: Local signal pad for synchronization
        a_shared_tuple: Tuple of shared tensors from all devices in the group
        my_rank: Current device's rank in the distributed group

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    _, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)

    for tile_n in hl.tile(N):
        # Sync all devices through signal_pad to make sure
        # all previous writes to the shared tensor are visible
        ptr_tile = signal_pad_addrs[:]
        stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile)
        hl.signal(
            stack_signalpad,
            [tile_n.id, my_rank],
            signal=1,
            wait_for=0,
            scope="sys",
            hasPreviousMemAccess=False,
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
            )

        acc = hl.zeros(
            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
        )

        for a in a_shared_tuple:
            acc += a[tile_n]

        out[tile_n] = acc

        # Sync all devices through signal_pad to make sure our writes to the shared
        # tensor are visible to subsequent kernels.
        hl.signal(
            stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
                hasSubsequentMemAccess=False,
            )
    return out


# %%
# Extract tensors from the symmetric memory handle
# ----------------------------------------
def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    """
    Prepares symmetric memory tensors for the Helion one-shot all-reduce kernel:
    the shared buffers from every rank are gathered as a tuple of tensors, and the
    signal pad pointers are passed as a dev_ptrs tensor.

    Args:
        a_shared: Input tensor to be all-reduced across all devices

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    assert dist.group.WORLD is not None

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

    a_shared_tuple = tuple(
        [
            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
            for i in range(symm_mem_hdl.world_size)
        ]
    )

    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, symm_mem_hdl.world_size)

    signal_pad_addrs = dev_array_to_tensor_short(
        symm_mem_hdl.signal_pad_ptrs_dev,
        (symm_mem_hdl.world_size,),
        dtype=torch.uint64,
        device=a_shared.device,
    )

    return one_shot_all_reduce_kernel(
        signal_pad_addrs,
        local_signal_pad,
        a_shared_tuple,
        my_rank=symm_mem_hdl.rank,
    )


# %%
# Testing Function
# ----------------------------------------
def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
    """
    Test the Helion all-reduce implementation against PyTorch's reference implementation.

    Args:
        N: Total number of elements to test (will be divided by world_size per device)
        device: CUDA device to run the test on
        dtype: Data type for the test tensors
    """
    dist_group = dist.group.WORLD
    assert dist_group is not None

    world_size = dist.get_world_size()
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    a_shared_clone = symm_mem.empty(
        a_shared.shape,
        dtype=a_shared.dtype,
        device=a_shared.device,
    )
    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
    a_shared_clone.copy_(a_shared)

    a_out = helion_one_shot_all_reduce(a_shared)

    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
        a_shared_clone, "sum", dist_group.group_name
    )

    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
    """
    Main entry point for the all-reduce example.

    Sets up the distributed environment, initializes CUDA devices, runs the
    all-reduce test, and then cleans up.
    """
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(16384, device, torch.bfloat16)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_reduce.py
    """
    main()
29 changes: 23 additions & 6 deletions helion/language/creation_ops.py
@@ -19,7 +19,11 @@
__all__ = ["arange", "full", "zeros"]


def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
def zeros(
    shape: list[object],
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Return a device-tensor filled with zeros.

@@ -55,12 +59,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
        - :func:`~helion.language.full`: For filling with arbitrary values
        - :func:`~helion.language.arange`: For creating sequences
    """
    return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
    return full(
        shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
    )


@_decorators.api(tiles_as_sizes=True)
def full(
    shape: list[object], value: float, dtype: torch.dtype = torch.float32
    shape: list[object],
    value: float,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Create a device-tensor filled with a specified value.
@@ -104,6 +113,7 @@ def _full_fake(
    shape: list[int | torch.SymInt],
    value: float,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -112,7 +122,7 @@ def _full_fake(
    return torch.empty(
        [*shape],
        dtype=dtype,
        device=env.device,
        device=env.device if device is None else device,
    )


@@ -150,6 +160,7 @@ def _(
    shape: list[int | RefTile],
    value: float,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    processed_shape = []
    for s in shape:
@@ -158,12 +169,18 @@ def _(
        else:
            processed_shape.append(s)
    env = CompileEnvironment.current()
    return torch.full(processed_shape, value, dtype=dtype, device=env.device)
    return torch.full(
        processed_shape,
        value,
        dtype=dtype,
        device=env.device if device is None else device,
    )


def arange(
    *args: int,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    **kwargs: object,
) -> torch.Tensor:
    """
@@ -192,5 +209,5 @@ def arange(
        *args,
        **kwargs,
        dtype=dtype,
        device=env.device,
        device=env.device if device is None else device,
    )
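
The new optional device argument above defaults to the compile environment's device, so existing kernels are unaffected; the all-reduce kernel uses it to place its accumulator explicitly via device=local_signal_pad.device. A minimal sketch of the parameter in an ordinary kernel, modeled on the hl_zeros_usage test kernel (the kernel below is illustrative only and not part of this PR):

import torch

import helion
import helion.language as hl


@helion.jit(static_shapes=True)
def zeros_with_device(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        # Explicitly request the accumulator's device instead of relying on
        # the compile environment's default.
        tmp = hl.zeros(tile, dtype=x.dtype, device=x.device)
        tmp += x[tile]
        out[tile] = tmp
    return out
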
6 changes: 3 additions & 3 deletions test/test_type_propagation.expected
@@ -495,7 +495,7 @@ def root_graph_0():
    # File: .../basic_kernels.py:40 in hl_full_usage, code: tmp = hl.full(tile, 1, dtype=x.dtype)
    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32)
    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 1, torch.int32, None)

    # File: .../basic_kernels.py:41 in hl_full_usage, code: tmp += x[tile]
    x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -553,7 +553,7 @@ def root_graph_0():
    # File: .../basic_kernels.py:29 in hl_zeros_usage, code: tmp = hl.zeros(tile, dtype=x.dtype)
    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32)
    tmp: "i32[u0, u1]" = helion_language_creation_ops_full((block_size_0, block_size_1), 0, torch.int32, None)

    # File: .../basic_kernels.py:30 in hl_zeros_usage, code: tmp += x[tile]
    x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
@@ -679,7 +679,7 @@ def root_graph_1():
    # File: .../matmul.py:52 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32)
    acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32, None)

    # File: .../matmul.py:53 in matmul, code: for tile_k in hl.tile(k):
    _for_loop = helion_language__tracing_ops__for_loop(0, [0], [512], [acc])