Skip to content

Commit cb2cbfc

Browse files
brb-nv authored and JunyiXu-nv committed
[TRTLLM-9493][feat] Custom AllToAll for helix parallelism (NVIDIA#9986)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent e31275d commit cb2cbfc

File tree

14 files changed

+1242
-107
lines changed

14 files changed

+1242
-107
lines changed

cpp/tensorrt_llm/kernels/helixAllToAll.cu

Lines changed: 683 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

// Describes one data field exchanged per entry by the helix all-to-all
// (field 0 carries the attention partial outputs, field 1 the softmax stats).
struct HelixFieldInfo
{
    uint8_t* dataPtr; // Device pointer to the field's buffer
    int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for
                      // field 1)
    int elementSize; // Size of each element in bytes (2 for half, 8 for float2)
    int stride; // Stride between rows in bytes
};

// Launch parameters for launchHelixAllToAll. Exactly two send and two receive
// fields are exchanged per entry across the cpSize context-parallel ranks.
struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2]; // Source fields on this rank
    HelixFieldInfo recvFields[2]; // Destination fields on this rank
    int entryCount; // Number of entries per peer rank to process
    uint64_t* workspace; // Synchronization/staging workspace (see init below)
    int workspaceStrideInU64; // Distance in uint64 words between ranks' workspace rows
    int cpRank; // This rank's index within the context-parallel group
    int cpSize; // Total number of context-parallel ranks
    int channelCount; // use 0 to auto-compute
    int maxChannelCount; // Upper bound, from computeHelixMaxChannelCount()
};

// ============================================================================
// Workspace Management Functions
// ============================================================================

/**
 * Compute number of channels for communication based on cpSize.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Number of channels to use
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);

/**
 * Compute the workspace size required per rank for the all-to-all operation.
 *
 * @param cpSize Number of context parallel ranks
 * @return Size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize);

/**
 * Initialize workspace memory for a given rank.
 * Should be called once during setup.
 *
 * @param workspace Pointer to workspace memory (per-rank view)
 * @param cpSize Number of context parallel ranks
 * @param stream CUDA stream for asynchronous operations
 */
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);

/**
 * Launch the helix all-to-all kernel.
 *
 * @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow variable field 1
 * @param stream CUDA stream for kernel launch
 */
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <nanobind/nanobind.h>
1919
#include <nanobind/stl/optional.h>
2020
#include <nanobind/stl/vector.h>
21+
#include <tensorrt_llm/kernels/helixAllToAll.h>
2122
#include <tensorrt_llm/thop/attentionOp.h>
2223
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
2324
#include <torch/extension.h>
@@ -73,5 +74,10 @@ void initBindings(nb::module_& m)
7374
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
7475
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
7576
nb::call_guard<nb::gil_scoped_release>());
77+
78+
m.def(
79+
"get_helix_workspace_size_per_rank",
80+
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
81+
nb::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
7682
}
7783
} // namespace tensorrt_llm::nanobind::thop

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <pybind11/functional.h>
1919
#include <pybind11/pybind11.h>
2020
#include <pybind11/stl.h>
21+
#include <tensorrt_llm/kernels/helixAllToAll.h>
2122
#include <tensorrt_llm/thop/attentionOp.h>
2223
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
2324
#include <torch/extension.h>
@@ -73,5 +74,10 @@ void initBindings(pybind11::module_& m)
7374
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
7475
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
7576
py::call_guard<py::gil_scoped_release>());
77+
78+
m.def(
79+
"get_helix_workspace_size_per_rank",
80+
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
81+
py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
7682
}
7783
} // namespace tensorrt_llm::pybind::thop

cpp/tensorrt_llm/thop/alltoallOp.cpp

Lines changed: 149 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,12 @@
1616
*/
1717

1818
#include "tensorrt_llm/common/opUtils.h"
19+
#include "tensorrt_llm/kernels/helixAllToAll.h"
1920
#include "tensorrt_llm/runtime/torchUtils.h"
2021
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
22+
#include "tensorrt_llm/thop/thUtils.h"
2123

22-
#include <NvInferRuntime.h>
23-
#include <c10/cuda/CUDAStream.h>
24-
#include <cassert>
25-
#include <set>
26-
#include <string>
27-
#include <torch/extension.h>
2824
#include <vector>
29-
#if ENABLE_MULTI_DEVICE
30-
#include <nccl.h>
31-
#endif // ENABLE_MULTI_DEVICE
3225

3326
TRTLLM_NAMESPACE_BEGIN
3427

@@ -119,16 +112,163 @@ std::vector<torch::Tensor> alltoall_helix(
119112
#endif // ENABLE_MULTI_DEVICE
120113
}
121114

115+
/**
116+
* Helix All-to-All operation with two fields.
117+
*
118+
* Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [...,
119+
* cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size
120+
* dimension across all ranks.
121+
*
122+
* @param partial_o Field 0 tensor (half precision, shape [..., cp_size,
123+
* kv_lora_rank])
124+
* @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2])
125+
* @param workspace Workspace tensor (uint64, strided across ranks)
126+
* @param cp_rank Current context parallel rank
127+
* @param cp_size Total number of context parallel ranks
128+
* @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
129+
*/
130+
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
131+
torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
132+
{
133+
134+
// Input validation
135+
CHECK_TH_CUDA(partial_o);
136+
CHECK_TH_CUDA(softmax_stats);
137+
CHECK_TH_CUDA(workspace);
138+
CHECK_CONTIGUOUS(partial_o);
139+
CHECK_CONTIGUOUS(softmax_stats);
140+
141+
// Type checks
142+
TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
143+
"partial_o must be half or bfloat16");
144+
CHECK_TYPE(softmax_stats, at::ScalarType::Float);
145+
CHECK_TYPE(workspace, at::ScalarType::UInt64);
146+
147+
// Shape validation
148+
TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
149+
TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
150+
TORCH_CHECK(
151+
partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
152+
153+
// Get dimensions
154+
int kv_lora_rank = partial_o.size(-1);
155+
TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
156+
"partial_o/softmax_stats second-to-last dimension must equal cp_size");
157+
TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2,
158+
"softmax_stats last dimension must be divisible by 2 (float2)");
159+
bool allowVariableField1 = softmax_stats.size(-1) > 2;
160+
161+
// Check that leading dimensions match
162+
for (int i = 0; i < partial_o.dim() - 2; i++)
163+
{
164+
TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i),
165+
"partial_o and softmax_stats must have matching dimensions except last two");
166+
}
167+
TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
168+
169+
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
170+
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
171+
172+
// Calculate entry count (product of all dimensions before cp_size)
173+
// This is the number of entries to process per peer rank
174+
int entry_count = 1;
175+
for (int i = 0; i < partial_o.dim() - 2; i++)
176+
{
177+
entry_count *= partial_o.size(i);
178+
}
179+
180+
// Reshape to 3D: [entry_count, cp_size, feature_dim]
181+
torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
182+
torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
183+
184+
// Allocate output tensors (same shape as input)
185+
torch::Tensor partial_o_out = torch::empty_like(partial_o);
186+
torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
187+
188+
torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
189+
torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
190+
191+
// Setup parameters
192+
tensorrt_llm::kernels::HelixAllToAllParams params;
193+
194+
// Field 0 (variable size half)
195+
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
196+
params.sendFields[0].elementCount = kv_lora_rank;
197+
params.sendFields[0].elementSize = partial_o.element_size();
198+
params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
199+
200+
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
201+
params.recvFields[0].elementCount = kv_lora_rank;
202+
params.recvFields[0].elementSize = partial_o.element_size();
203+
params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
204+
205+
// Field 1 (single float2)
206+
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
207+
params.sendFields[1].elementCount = softmax_stats.size(-1);
208+
params.sendFields[1].elementSize = softmax_stats.element_size();
209+
params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
210+
211+
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
212+
params.recvFields[1].elementCount = softmax_stats.size(-1);
213+
params.recvFields[1].elementSize = softmax_stats.element_size();
214+
params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
215+
216+
// Entry count and workspace
217+
params.entryCount = entry_count;
218+
params.workspace = workspace.data_ptr<uint64_t>();
219+
params.workspaceStrideInU64 = workspace.stride(0);
220+
221+
// CP info
222+
params.cpRank = cp_rank;
223+
params.cpSize = cp_size;
224+
params.channelCount = 0; // auto-compute
225+
params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size);
226+
227+
// Launch kernel
228+
auto stream = at::cuda::getCurrentCUDAStream();
229+
tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
230+
231+
return std::make_tuple(partial_o_out, softmax_stats_out);
232+
}
233+
234+
/**
235+
* Initialize workspace for helix all-to-all
236+
*/
237+
void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
238+
{
239+
CHECK_TH_CUDA(workspace);
240+
CHECK_TYPE(workspace, at::ScalarType::UInt64);
241+
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D");
242+
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
243+
TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");
244+
245+
auto stream = at::cuda::getCurrentCUDAStream();
246+
uint64_t* global_workspace_ptr = workspace.data_ptr<uint64_t>();
247+
uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr<uint64_t>();
248+
TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0),
249+
"local_workspace_ptr must be at the correct offset in the global "
250+
"workspace");
251+
tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream);
252+
}
253+
122254
} // namespace torch_ext
123255

124256
TRTLLM_NAMESPACE_END
125257

// Schema registration for the trtllm custom ops (dispatch-key agnostic).
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
    // Native helix all-to-all; Tensor(a!) marks workspace as mutated in place.
    m.def(
        "alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int "
        "cp_rank, int cp_size) -> (Tensor, Tensor)");
    // One-time per-rank workspace initialization; returns nothing.
    m.def(
        "initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) "
        "-> ()");
}
130268

// CUDA-backend implementations for the schemas registered above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix);
    m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native);
    m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace);
}

0 commit comments

Comments (0)