
Commit d0f107e

[TRTLLM-5966][feat] Helix: add full MLA support for Helix (#8104)
Signed-off-by: Matthias Jouanneaux <[email protected]>
1 parent 5e6f1bc commit d0f107e

File tree

11 files changed: +1709 -110 lines changed
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/helixKernels.h"

#include <cstdint>
#include <cstdio>

#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

using namespace tensorrt_llm::common;

namespace cg = cooperative_groups;

namespace tensorrt_llm
{
namespace kernels
{
static constexpr int WARP_SIZE = 32;

// Utility: warp-level corrected sum
template <int N>
__device__ inline void warpReduceCorrectedSum(float (&correctedVal)[N], float (&maxVal)[N], float (&sumVal)[N])
{
    float warp_max = maxVal[0];
#pragma unroll
    for (int nn = 1; nn < N; ++nn)
        warp_max = fmaxf(warp_max, maxVal[nn]);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
    asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(warp_max) : "f"(warp_max));
#else
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset *= 2)
        warp_max = fmaxf(warp_max, __shfl_xor_sync(0xffffffff, warp_max, offset));
#endif
    float global_sum = 0.F;
    float corrected_max_exp[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        corrected_max_exp[nn] = sumVal[nn] * expf(maxVal[nn] - warp_max);
        global_sum += corrected_max_exp[nn];
    }
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset *= 2)
        global_sum += __shfl_xor_sync(0xffffffff, global_sum, offset);
    auto norm = 1.F / global_sum;
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
        correctedVal[nn] = corrected_max_exp[nn] * norm;
}

static constexpr int MAX_CP_VAL_PER_THREAD = 8;
static constexpr int MAX_CP = WARP_SIZE * MAX_CP_VAL_PER_THREAD;
static constexpr int BYTES_O_PER_THREAD = 16;
static constexpr int NUM_PRE_LOAD = 8;

// Kernel: fused helix post-processing
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [cp_size, num_tokens, num_heads * kv_lora_rank] (half)
// gathered_stats: [cp_size, num_tokens, num_heads, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
template <typename T>
__global__ void helix_postprocess_kernel(
    T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
    // Each block processes one (token, head)
    // gridDim.x: num_tokens, gridDim.y: num_heads
    // there are two separate types of warps:
    // warp 0 calculates the correction values (one per cp_size)
    // all other warps pre-load the gathered_o elements for the current token/head
    // and once warp 0 is done, all other warps can start accumulating the output
    static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);

    int tok_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    int num_tokens = gridDim.x;
    int num_heads = gridDim.y;

    int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
    __shared__ float smem_correction[MAX_CP];

    int lane_idx = threadIdx.x % WARP_SIZE;
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
    // here we have to wait for memory operations of the previous kernel to complete
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    if (warp_idx == 0)
    {
        // the warp collectively calculates the correction values
        float max_values[MAX_CP_VAL_PER_THREAD];
        float sum_values[MAX_CP_VAL_PER_THREAD];
#pragma unroll
        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
        {
            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
            auto stats_offset = cp_idx * num_tokens * num_heads + tok_idx * num_heads + head_idx;
            float2 stats = cp_idx < cp_size ? gathered_stats[stats_offset] : make_float2(-INFINITY, 0.F);
            max_values[cp_val_idx] = stats.x;
            sum_values[cp_val_idx] = stats.y;
        }
        float corrected_values[MAX_CP_VAL_PER_THREAD];
        warpReduceCorrectedSum(corrected_values, max_values, sum_values);
#pragma unroll
        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
        {
            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
            smem_correction[cp_idx] = corrected_values[cp_val_idx];
        }
        cg::this_thread_block().sync();
    }
    else
    {
        // all other warps pre-load the gathered_o elements for the current token/head
        auto const* gathered_o_off = gathered_o + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
        // we subtract WARP_SIZE because first warp is not participating here
        gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
        float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
        auto gathered_16b_stride = (num_tokens * num_heads * kv_lora_rank) / NUM_O_PER_THREAD;
        T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
            auto val
                = cp_idx < cp_size ? gathered_o_16b[cp_idx * gathered_16b_stride] : make_float4(0.F, 0.F, 0.F, 0.F);
            *reinterpret_cast<float4*>(vals[cp_idx]) = val;
        }
        float final_sum[NUM_O_PER_THREAD];
#pragma unroll
        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
        {
            final_sum[o_idx] = 0.F;
        }
        cg::this_thread_block().sync();

        // here we can trigger the dependent kernels to start
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaTriggerProgrammaticLaunchCompletion();
#endif

        float corr_vals[NUM_PRE_LOAD];
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
            corr_vals[cp_idx] = smem_correction[cp_idx];
        }

        for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
        {
#pragma unroll
            for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
            {
#pragma unroll
                for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
                {
                    final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
                }
            }
#pragma unroll
            for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
            {
                *reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
                    ? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
                    : make_float4(0.F, 0.F, 0.F, 0.F);
                corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
            }
        }
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
#pragma unroll
            for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
            {
                final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
            }
        }
        T output_typed[NUM_O_PER_THREAD];
#pragma unroll
        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
        {
            output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
        }
        auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
        output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
        *reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
    }
}

template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
    // Check that gathered_o is 16-byte aligned
    TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
        "gathered_o must be 16-byte aligned for async memcpy");
    // Check that kv_lora_rank * sizeof(T) is a multiple of 16
    TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
        "kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
    // Check that cp_size is not larger than the max fallback CP size
    TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");

    auto* kernel_instance = &helix_postprocess_kernel<T>;
    cudaLaunchConfig_t config;
    config.gridDim = dim3(params.num_tokens, params.num_heads);
    config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
        params.gathered_stats, params.cp_size, params.kv_lora_rank));
}

#define INSTANTIATE_POST_PROC(T)                                                                                       \
    template void helixPostProcess<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);

INSTANTIATE_POST_PROC(__half);
INSTANTIATE_POST_PROC(__nv_bfloat16);

} // namespace kernels
} // namespace tensorrt_llm
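
For reference, the combine performed by warpReduceCorrectedSum plus the accumulation loops above is the standard log-sum-exp merge of per-rank attention partials: each context-parallel rank i contributes a partial output o_i together with its row max m_i and softmax denominator s_i (stats.x and stats.y), and the merged output is sum_i o_i * s_i * exp(m_i - m) / sum_j s_j * exp(m_j - m), with m = max_i m_i. The host-side C++ sketch below mirrors that math for a single (token, head) row; it is illustrative only and not part of the commit, and the function and variable names are assumptions.

// Host-side reference for the per-(token, head) combine done by helix_postprocess_kernel.
// Illustrative sketch only; not part of the commit.
#include <cmath>
#include <utility>
#include <vector>

// gathered_o_row:     cp_size x kv_lora_rank partial outputs for one (token, head)
// gathered_stats_row: cp_size pairs of (row max m_i, softmax denominator s_i)
std::vector<float> helixCombineReference(std::vector<std::vector<float>> const& gathered_o_row,
    std::vector<std::pair<float, float>> const& gathered_stats_row, int kv_lora_rank)
{
    int const cp_size = static_cast<int>(gathered_o_row.size());
    // Global max over all ranks, matching the warp-wide max reduction.
    float m = -INFINITY;
    for (int i = 0; i < cp_size; ++i)
        m = std::fmax(m, gathered_stats_row[i].first);
    // Corrected denominators and their sum, matching warpReduceCorrectedSum.
    std::vector<float> corr(cp_size);
    float denom = 0.f;
    for (int i = 0; i < cp_size; ++i)
    {
        corr[i] = gathered_stats_row[i].second * std::exp(gathered_stats_row[i].first - m);
        denom += corr[i];
    }
    // Weighted sum of the partial outputs, matching the accumulation loops in the kernel.
    std::vector<float> out(kv_lora_rank, 0.f);
    for (int i = 0; i < cp_size; ++i)
        for (int d = 0; d < kv_lora_rank; ++d)
            out[d] += gathered_o_row[i][d] * (corr[i] / denom);
    return out;
}

The kernel computes exactly this, but splits the work so that warp 0 produces corr[i]/denom in shared memory while the remaining warps stream the 16-byte vectors of gathered_o and accumulate.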
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace tensorrt_llm
{
namespace kernels
{
template <typename T>
struct HelixPostProcParams
{
    T* output;
    T const* gathered_o;
    float2 const* gathered_stats;
    int cp_size;
    int num_tokens;
    int num_heads;
    int kv_lora_rank;
};

template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
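
A hypothetical call site for this API might look like the sketch below: fill HelixPostProcParams with device buffers laid out as documented in the kernel comments and launch on a CUDA stream. The wrapper function name and the assumption that the buffers are already allocated and gathered across CP ranks are illustrative, not part of the commit.

// Sketch of a call site (assumption, not from the commit). All pointers are device memory.
#include "tensorrt_llm/kernels/helixKernels.h"

void runHelixPostProcessExample(__half* output, __half const* gathered_o, float2 const* gathered_stats, int cp_size,
    int num_tokens, int num_heads, int kv_lora_rank, cudaStream_t stream)
{
    tensorrt_llm::kernels::HelixPostProcParams<__half> params{};
    params.output = output;                 // [num_tokens, num_heads * kv_lora_rank]
    params.gathered_o = gathered_o;         // [cp_size, num_tokens, num_heads * kv_lora_rank]
    params.gathered_stats = gathered_stats; // [cp_size, num_tokens, num_heads, 2]
    params.cp_size = cp_size;
    params.num_tokens = num_tokens;
    params.num_heads = num_heads;
    params.kv_lora_rank = kv_lora_rank;
    tensorrt_llm::kernels::helixPostProcess(params, stream);
}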

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 4 additions & 3 deletions
@@ -187,7 +187,7 @@ template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheB
 __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* q_pe, T* k_ptr, T const* fuse_buf,
     KVCacheBuffer kv_cache, int q_pe_ld, int q_pe_stride, float2 const* cos_sin_cache, size_t head_num, int head_size,
     int c_k, int* cu_q_seqlens, int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type,
-    float const* quant_scale_kv, bool absorption_mode)
+    float const* quant_scale_kv, int32_t const* helix_position_offsets, bool absorption_mode)
 {

     // Constants.
@@ -237,7 +237,8 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* q_pe, T* k
             local_token_idx = std::min(local_token_idx, cache_seq_len - 1);
             int const global_token_idx = local_token_idx + global_token_offset;

-            auto const position_id = local_token_idx;
+            auto const position_id
+                = helix_position_offsets ? helix_position_offsets[global_token_idx] : local_token_idx;
             float2 const* rotary_coef_cache_buffer
                 = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

@@ -949,7 +950,7 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c
         params.q_pe, params.k_buf, params.latent_cache, kv_cache_buffer, params.q_pe_ld, params.q_pe_stride,
         params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens,
         params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv,
-        params.absorption_mode);
+        params.helix_position_offsets, params.absorption_mode);
 }

 template <typename T>
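
This change lets the context-phase RoPE kernel take its position id from an optional helix_position_offsets buffer instead of the token's local index. That matters once a sequence is split across context-parallel ranks: a token's local index on a rank no longer equals its absolute position in the sequence, so the cos/sin cache lookup needs the true position. The snippet below is a hypothetical host-side illustration of how such a per-token position buffer could be filled for a rank owning a contiguous slice of each sequence; buildHelixPositionOffsets, rank_start_pos, and seq_lens are assumed names, and the actual construction in TensorRT-LLM may differ.

// Illustrative sketch only (not from the commit): per-token RoPE position ids for one CP rank.
#include <cstdint>
#include <vector>

// rank_start_pos[s]: absolute position of the first token of sequence s held by this rank (assumption)
// seq_lens[s]:       number of tokens of sequence s held by this rank
std::vector<int32_t> buildHelixPositionOffsets(
    std::vector<int32_t> const& rank_start_pos, std::vector<int32_t> const& seq_lens)
{
    std::vector<int32_t> offsets;
    for (size_t s = 0; s < seq_lens.size(); ++s)
        for (int32_t t = 0; t < seq_lens[s]; ++t)
            offsets.push_back(rank_start_pos[s] + t); // absolute position fed to the cos/sin cache lookup
    return offsets;
}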

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ add_library(
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
+  helixPostProcessOp.cpp
   llama4MinLatency.cpp
   logitsBitmaskOp.cpp
   mambaConv1dOp.cpp
