Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
683 changes: 683 additions & 0 deletions cpp/tensorrt_llm/kernels/helixAllToAll.cu

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions cpp/tensorrt_llm/kernels/helixAllToAll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

/**
 * Describes one data field (send or receive side) moved by the helix
 * all-to-all kernel. Each entry is a row of `elementCount` elements of
 * `elementSize` bytes; consecutive rows are `stride` bytes apart.
 */
struct HelixFieldInfo
{
    uint8_t* dataPtr; // Device pointer to the first row of this field
    int elementCount; // Number of elements per row (e.g., kv_lora_rank for field 0, 1 for
                      // field 1 — note alltoallOp.cpp instead passes 2 floats of 4 bytes
                      // for field 1; both describe the same 8 bytes per row)
    int elementSize;  // Size of each element in bytes (2 for half, 8 for float2)
    int stride;       // Stride between rows in bytes
};

/**
 * Parameter bundle for launchHelixAllToAll.
 */
struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2]; // Fields read from this rank and sent to peers
    HelixFieldInfo recvFields[2]; // Fields written on this rank with data from peers
    int entryCount;               // Number of entries per peer rank to process
    uint64_t* workspace;          // Communication workspace (see initializeHelixWorkspace)
    int workspaceStrideInU64;     // Stride between per-rank workspace rows, in uint64 units
    int cpRank;                   // This rank's index, expected in [0, cpSize)
    int cpSize;                   // Total number of context parallel ranks
    int channelCount;             // use 0 to auto-compute
    int maxChannelCount;          // Upper bound, from computeHelixMaxChannelCount()
};

// ============================================================================
// Workspace Management Functions
// ============================================================================

/**
 * Compute the maximum number of communication channels for the all-to-all,
 * based on cpSize and the available SM count. The result is intended for
 * HelixAllToAllParams::maxChannelCount.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Maximum number of channels
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);

/**
 * Compute the workspace size required per rank for the all-to-all operation.
 * Each rank must provide at least this many bytes of device memory.
 *
 * @param cpSize Number of context parallel ranks
 * @return Required per-rank workspace size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize);

/**
 * Initialize workspace memory for a given rank.
 * Should be called once during setup, before the first launchHelixAllToAll.
 *
 * @param workspace Pointer to this rank's workspace memory (per-rank view,
 *                  at least computeHelixWorkspaceSizePerRank(cpSize) bytes)
 * @param cpSize Number of context parallel ranks
 * @param stream CUDA stream the initialization is enqueued on (asynchronous)
 */
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);

/**
 * Launch the helix all-to-all kernel.
 *
 * @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow variable field 1; the call site
 *                            in alltoallOp.cpp sets this when field 1 carries
 *                            more than one float2 per entry — confirm exact
 *                            semantics against the .cu implementation
 * @param stream CUDA stream for kernel launch
 */
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
Expand Down Expand Up @@ -73,5 +74,10 @@ void initBindings(nb::module_& m)
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());

m.def(
"get_helix_workspace_size_per_rank",
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
nb::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::nanobind::thop
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/pybind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
Expand Down Expand Up @@ -73,5 +74,10 @@ void initBindings(pybind11::module_& m)
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());

m.def(
"get_helix_workspace_size_per_rank",
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::pybind::thop
158 changes: 149 additions & 9 deletions cpp/tensorrt_llm/thop/alltoallOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,12 @@
*/

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/helixAllToAll.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <NvInferRuntime.h>
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <set>
#include <string>
#include <torch/extension.h>
#include <vector>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE

TRTLLM_NAMESPACE_BEGIN

Expand Down Expand Up @@ -119,16 +112,163 @@ std::vector<torch::Tensor> alltoall_helix(
#endif // ENABLE_MULTI_DEVICE
}

/**
 * Helix All-to-All operation with two fields.
 *
 * Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and
 * [..., cp_size, 2*k] for softmax_stats (k >= 1; k > 1 enables the
 * variable-size field-1 path). The operation exchanges data along the
 * cp_size dimension across all ranks.
 *
 * @param partial_o Field 0 tensor (half or bfloat16, shape
 * [..., cp_size, kv_lora_rank]; last dim must be 16-byte aligned)
 * @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2*k])
 * @param workspace Workspace tensor (uint64, strided across ranks)
 * @param cp_rank Current context parallel rank, in [0, cp_size)
 * @param cp_size Total number of context parallel ranks
 * @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
 */
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
    torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
    // Input validation: all tensors live on the GPU; the data tensors must be
    // contiguous so the reshapes below are views with predictable strides.
    CHECK_TH_CUDA(partial_o);
    CHECK_TH_CUDA(softmax_stats);
    CHECK_TH_CUDA(workspace);
    CHECK_CONTIGUOUS(partial_o);
    CHECK_CONTIGUOUS(softmax_stats);

    // Type checks
    TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
        "partial_o must be half or bfloat16");
    CHECK_TYPE(softmax_stats, at::ScalarType::Float);
    CHECK_TYPE(workspace, at::ScalarType::UInt64);

    // CP sanity checks (consistent with initialize_helix_workspace)
    TORCH_CHECK(cp_size > 0, "cp_size must be positive");
    TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");

    // Shape validation
    TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
    TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
    TORCH_CHECK(
        partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");

    // Get dimensions
    int kv_lora_rank = partial_o.size(-1);
    TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
        "partial_o/softmax_stats second-to-last dimension must equal cp_size");
    TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2,
        "softmax_stats last dimension must be divisible by 2 (float2)");
    // More than one float2 per entry selects the variable-size field-1 path.
    bool allowVariableField1 = softmax_stats.size(-1) > 2;

    // Check that leading dimensions match
    for (int i = 0; i < partial_o.dim() - 2; i++)
    {
        TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i),
            "partial_o and softmax_stats must have matching dimensions except last two");
    }
    // 16-byte alignment so the kernel can use vectorized accesses on field 0.
    TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");

    TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
    TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");

    // Calculate entry count (product of all dimensions before cp_size)
    // This is the number of entries to process per peer rank
    int entry_count = 1;
    for (int i = 0; i < partial_o.dim() - 2; i++)
    {
        entry_count *= partial_o.size(i);
    }

    // Reshape to 3D: [entry_count, cp_size, feature_dim]. The inputs are
    // contiguous, so these are views and stride(1) is the per-row stride.
    torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
    torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});

    // Allocate output tensors (same shape as input)
    torch::Tensor partial_o_out = torch::empty_like(partial_o);
    torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);

    torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
    torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});

    // Setup parameters
    tensorrt_llm::kernels::HelixAllToAllParams params;

    // Field 0: attention partial outputs (half/bfloat16, kv_lora_rank per row)
    params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
    params.sendFields[0].elementCount = kv_lora_rank;
    params.sendFields[0].elementSize = partial_o.element_size();
    params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();

    params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
    params.recvFields[0].elementCount = kv_lora_rank;
    params.recvFields[0].elementSize = partial_o.element_size();
    params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();

    // Field 1: softmax stats (float32, one or more float2 pairs per row)
    params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
    params.sendFields[1].elementCount = softmax_stats.size(-1);
    params.sendFields[1].elementSize = softmax_stats.element_size();
    params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();

    params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
    params.recvFields[1].elementCount = softmax_stats.size(-1);
    params.recvFields[1].elementSize = softmax_stats.element_size();
    params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();

    // Entry count and workspace
    params.entryCount = entry_count;
    params.workspace = workspace.data_ptr<uint64_t>();
    params.workspaceStrideInU64 = workspace.stride(0);

    // CP info
    params.cpRank = cp_rank;
    params.cpSize = cp_size;
    params.channelCount = 0; // auto-compute
    params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size);

    // Launch kernel on the current torch stream (asynchronous)
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);

    return std::make_tuple(partial_o_out, softmax_stats_out);
}

/**
 * Initialize the helix all-to-all workspace for the calling rank.
 *
 * Validates the shared 2D workspace tensor (uint64, one row per rank) and
 * runs the kernel-side initialization on this rank's row only.
 *
 * @param workspace 2D uint64 CUDA tensor, cp_size rows (one per rank)
 * @param cp_rank This rank's index, must be in [0, cp_size)
 * @param cp_size Total number of context parallel ranks
 */
void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
    CHECK_TH_CUDA(workspace);
    CHECK_TYPE(workspace, at::ScalarType::UInt64);
    TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D");
    TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
    TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");

    auto const stream = at::cuda::getCurrentCUDAStream();
    uint64_t* const base_ptr = workspace.data_ptr<uint64_t>();
    uint64_t* const rank_ptr = workspace[cp_rank].data_ptr<uint64_t>();
    // Sanity check: the row view must coincide with the strided offset into
    // the flat buffer, i.e. the workspace rows are laid out as expected.
    TORCH_CHECK(rank_ptr == base_ptr + cp_rank * workspace.stride(0),
        "local_workspace_ptr must be at the correct offset in the global "
        "workspace");
    tensorrt_llm::kernels::initializeHelixWorkspace(rank_ptr, cp_size, stream);
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

// Register the operator schemas in the `trtllm` library namespace.
// `Tensor(a!)` marks the workspace argument as mutated in place.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
    m.def(
        "alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int "
        "cp_rank, int cp_size) -> (Tensor, Tensor)");
    m.def(
        "initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) "
        "-> ()");
}

// Bind the CUDA implementations of the registered `trtllm` operators.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix);
    m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native);
    m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace);
}
Loading