
Commit 8000827

lfr-0531 authored and heyuhhh committed

[TRTLLM-8536][feat] Add the sparse attention framework and one use case--RocketKV support (NVIDIA#8086)

Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Co-authored-by: yuhangh <[email protected]>

1 parent 963ad46 commit 8000827

43 files changed: +5149 additions, -209 deletions


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 24 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -287,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
+    // Parameters for sparse attention
+    xqaParams.sparse_params = mRuntimeSparseAttentionParams;
+    xqaParams.use_sparse_attention = useTllmGenSparseAttention();

     // Cross attention parameters.
     xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@@ -813,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
 }

 size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
-    int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
+    int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
 {
     if (max_num_tokens == 0)
     {
@@ -909,11 +913,15 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
     size_t xqa_workspace_size = 0;
     if (mEnableXQA)
     {
-        int const XQA_NUM_BUFFERS = 7;
+        int const XQA_NUM_BUFFERS = 8;
         size_t xqa_workspaces[XQA_NUM_BUFFERS];
         size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1);
         size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1);
         size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2;
+        // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
+        size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
+            ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
+            : 0;
         xqa_workspaces[0] = cu_seqlens_size;
         xqa_workspaces[1] = cu_kv_seqlens_size;
         xqa_workspaces[2] = rotary_inv_freq_size;
@@ -922,7 +930,8 @@
         // Scales used for trtllm-gen kernels.
         xqa_workspaces[4] = sizeof(float) * 2;
         xqa_workspaces[5] = sizeof(float);
-        xqa_workspaces[6] = mXqaDispatcher->getWorkspaceSize(
+        xqa_workspaces[6] = sparse_attn_cache_size;
+        xqa_workspaces[7] = mXqaDispatcher->getWorkspaceSize(
             std::min<uint32_t>(mSpecDecodingMaxGenerationLength * max_num_seq, max_num_tokens));
         xqa_workspace_size
             = tc::calculateTotalWorkspaceSize(xqa_workspaces, XQA_NUM_BUFFERS, mXqaDispatcher->getWorkspaceAlignment());
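Note on the sizing above: per KV head, the sparse scratch holds one gathered sequence length per sequence plus 2 * max_blocks_per_sequence block offsets (K and V) per sequence. A minimal stand-alone sketch of the same arithmetic, with illustrative values; the helper name and the numbers below are hypothetical, not part of this commit:

#include <cstddef>
#include <cstdio>

// Recomputes sparse_attn_cache_size from getWorkspaceSizeForGeneration above
// (hypothetical free function, for illustration only).
size_t sparseAttnCacheSize(int batch_beam, int max_blocks_per_sequence, int num_kv_heads)
{
    // Per KV head: batch_beam gathered sequence lengths, plus
    // batch_beam * 2 * max_blocks_per_sequence gathered K/V page offsets.
    return sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * num_kv_heads;
}

int main()
{
    // Example: 4 sequences, up to 128 pages per sequence, 8 KV heads.
    std::printf("%zu bytes\n", sparseAttnCacheSize(4, 128, 8)); // (4 + 1024) * 8 * 4 = 32896 bytes
    return 0;
}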
@@ -1647,6 +1656,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;

+    // Sparse KV write
+    preprocessingParams.sparse_kv_indices = mRuntimeSparseAttentionParams.sparse_kv_indices;
+    preprocessingParams.sparse_kv_offsets = mRuntimeSparseAttentionParams.sparse_kv_offsets;
+
     // Scalars
     preprocessingParams.batch_size = params.batch_size;
     preprocessingParams.max_input_seq_len = params.input_seq_length;
@@ -1676,6 +1689,8 @@

     preprocessingParams.rotary_vision_start = mVisionStart;
     preprocessingParams.rotary_vision_length = mVisionLength;
+    preprocessingParams.is_last_chunk
+        = !mAttentionChunkSize.has_value() || (params.input_seq_length == params.max_past_kv_length);

     {
         std::string const beforeRopeStr = "ctx attention before RoPE at layer " + std::to_string(mLayerIdx);
@@ -1841,6 +1856,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
         sync_check_cuda_error(stream);
     }
+
+    if (!mIsMLAEnabled) // Only for non-MLA attention
+    {
+        invokeKvCachePostprocessing(preprocessingParams, stream);
+        sync_check_cuda_error(stream);
+    }
 }
 else
 {
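The second hunk above sets preprocessingParams.is_last_chunk, marking whether this context pass covers the final chunk of a chunked-prefill request: either chunking is disabled, or the current input reaches the full accumulated past-KV length. A hedged restatement of the predicate in isolation (the free-function form is illustrative, not part of the commit):

#include <optional>

// Mirrors the is_last_chunk expression above. attention_chunk_size is
// std::nullopt when chunked prefill is disabled.
bool isLastChunk(std::optional<int> attention_chunk_size, int input_seq_length, int max_past_kv_length)
{
    // With chunking disabled, every context pass is trivially the last chunk;
    // otherwise the last chunk is the one whose input reaches the full past-KV length.
    return !attention_chunk_size.has_value() || (input_seq_length == max_past_kv_length);
}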

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 35 additions & 13 deletions
@@ -26,6 +26,7 @@
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/mlaKernels.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/xqaDispatcher.h"
 #include <cassert>
 #include <set>
@@ -55,7 +56,7 @@ class AttentionOp
         int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept;
     // total_num_seq is the sum of beam_width for multiple requests
     [[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq,
-        int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept;
+        int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept;

     template <typename T>
     class EnqueueParams
@@ -156,14 +157,20 @@
         ss << "max_cyclic_attention_window_size: " << this->max_cyclic_attention_window_size << std::endl;
         ss << "can_use_one_more_block: " << (this->can_use_one_more_block ? "true" : "false") << std::endl;
         ss << "sink_token_length: " << this->sink_token_length << std::endl;
-        ss << "context_lengths: "
-           << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
-                  runtime::ITensor::makeShape({batch_size})))
-           << std::endl;
-        ss << "sequence_lengths: "
-           << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
-                  runtime::ITensor::makeShape({batch_size})))
-           << std::endl;
+        if (this->context_lengths && batch_size > 0)
+        {
+            ss << "context_lengths: "
+               << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
+                      runtime::ITensor::makeShape({batch_size})))
+               << std::endl;
+        }
+        if (this->sequence_lengths && batch_size > 0)
+        {
+            ss << "sequence_lengths: "
+               << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
+                      runtime::ITensor::makeShape({batch_size})))
+               << std::endl;
+        }
         ss << "kv_scale_orig_quant: " << this->kv_scale_orig_quant << std::endl;
         ss << "kv_scale_quant_orig: " << this->kv_scale_quant_orig << std::endl;
         ss << "attention_output_orig_quant: " << this->attention_output_orig_quant << std::endl;
@@ -348,6 +355,16 @@
         return mIsMLAEnabled;
     }

+    [[nodiscard]] bool useSparseAttention() const
+    {
+        return mUseSparseAttention && mPagedKVCache && mEnableXQA;
+    }
+
+    [[nodiscard]] bool useTllmGenSparseAttention() const
+    {
+        return mUseTllmGenSparseAttention && useSparseAttention();
+    }
+
     [[nodiscard]] int smVersion() const
     {
         return mSM;
@@ -427,6 +444,8 @@
     bool mIsMLAEnabled = false;
     bool mIsGenerationMLA = false;
     bool mUseGenFlashMLA = false;
+    bool mUseSparseAttention = false;
+    bool mUseTllmGenSparseAttention = false;
     tensorrt_llm::kernels::MlaMetaParams mMLAParams;
     int mCpSize = 1;
     int mCpRank = 0;
@@ -454,6 +473,8 @@
     // Whether to fuse FP4 quant into attention kernel.
     bool mFuseFp4Quant = false;

+    kernels::SparseAttentionParams mRuntimeSparseAttentionParams;
+
     // This is implementation details which we want to save when serializing, but not expose as
     // a plugin field or a constructor parameter
     int32_t mNbMultiBlockSemaphores = 0;
@@ -473,10 +494,11 @@
         mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
         mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
         mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
-        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
-        mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
-        mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
-        mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
+        mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
+        mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
+        mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+        mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
 };

 private:
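The two accessors added above gate sparse attention in layers: useSparseAttention() requires the feature flag plus the paged KV cache and the XQA path, and useTllmGenSparseAttention() additionally requires the trtllm-gen variant, so the trtllm-gen kernels can never be selected when the base preconditions fail. A minimal sketch of that layering; the member names are taken from the diff, the surrounding struct is illustrative:

// Illustrative reduction of the gating in AttentionOp (not the real class).
struct SparseAttentionGates
{
    bool mUseSparseAttention = false;
    bool mUseTllmGenSparseAttention = false;
    bool mPagedKVCache = false;
    bool mEnableXQA = false;

    // Sparse attention needs a paged KV cache (for page gathering) and the XQA kernels.
    bool useSparseAttention() const
    {
        return mUseSparseAttention && mPagedKVCache && mEnableXQA;
    }

    // The trtllm-gen sparse path is a strict specialization of the sparse path.
    bool useTllmGenSparseAttention() const
    {
        return mUseTllmGenSparseAttention && useSparseAttention();
    }
};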

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 13 additions & 0 deletions
@@ -233,6 +233,8 @@ struct XQALaunchParam
     float* bmm2_scale_ptr = nullptr;
     int32_t* semaphores = nullptr;
     void* scratch = nullptr;
+    void* sparse_kv_block_offsets = nullptr;
+    int32_t* sparse_seq_lengths = nullptr;
 };

 // Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion.
@@ -266,6 +268,9 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
     const size_t cu_kv_seqlens_size = sizeof(int) * (batch_beam_size + 1);
     const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
     const size_t tokens_info_size = sizeof(int2) * params.total_num_input_tokens;
+    const size_t kv_block_offsets_size
+        = sizeof(int) * batch_beam_size * 2 * params.max_blocks_per_sequence * params.num_kv_heads;
+    const size_t seq_lengths_size = sizeof(int) * batch_beam_size * params.num_kv_heads;
     launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
     launchParams.cu_kv_seq_lens = reinterpret_cast<int*>(workspace);
@@ -281,6 +286,14 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm1_scale_size);
     launchParams.bmm2_scale_ptr = reinterpret_cast<float*>(workspace);
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size);
+    // Used for block sparse attention
+    if (params.use_sparse_attention)
+    {
+        launchParams.sparse_kv_block_offsets = reinterpret_cast<void*>(workspace);
+        workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
+        launchParams.sparse_seq_lengths = reinterpret_cast<int*>(workspace);
+        workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, seq_lengths_size);
+    }
     inputScratch = workspace;
     if (hasOutputScratch)
     {
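buildXQALaunchParams carves the two sparse buffers out of the one pre-sized XQA workspace with the same bump-and-align pattern as the other buffers: record the current cursor as the buffer start, then advance past the buffer size to an aligned address. A simplified sketch of the pattern; bumpAligned is a stand-in for tensorrt_llm::common::nextWorkspacePtrWithAlignment, and the 128-byte default alignment is an assumption for illustration:

#include <cstddef>
#include <cstdint>

// Stand-in for tensorrt_llm::common::nextWorkspacePtrWithAlignment: advance the
// cursor past `bytes` and round up to the next `alignment`-byte boundary.
void* bumpAligned(void* ptr, size_t bytes, size_t alignment = 128)
{
    uintptr_t const next = reinterpret_cast<uintptr_t>(ptr) + bytes;
    return reinterpret_cast<void*>((next + alignment - 1) / alignment * alignment);
}

// Carve the two sparse-attention buffers out of one workspace, as above.
void carveSparseBuffers(void*& workspace, size_t kv_block_offsets_size, size_t seq_lengths_size,
    void*& sparse_kv_block_offsets, int32_t*& sparse_seq_lengths)
{
    sparse_kv_block_offsets = workspace;
    workspace = bumpAligned(workspace, kv_block_offsets_size);
    sparse_seq_lengths = reinterpret_cast<int32_t*>(workspace);
    workspace = bumpAligned(workspace, seq_lengths_size);
}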

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "tensorrt_llm/common/quantization.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

 namespace tensorrt_llm
 {
@@ -109,6 +110,10 @@ struct XQAParams
     // for cross attention
     int32_t const* encoder_input_lengths = nullptr;

+    // sparse attention parameters
+    SparseAttentionParams sparse_params;
+    bool use_sparse_attention = false;
+
     cudaStream_t stream = 0;

     std::string toString() const
@@ -179,6 +184,8 @@ struct XQAParams
            << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
            << "fp8_out_scale :" << fp8_out_scale << std ::endl
            << "encoder_input_lengths: " << encoder_input_lengths << std::endl
+           << "sparse_params: " << sparse_params.toString() << std::endl
+           << "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
            << "stream :" << stream;

     return ss.str();
Lines changed: 132 additions & 0 deletions (new file)
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
+#include <cub/cub.cuh>
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+template <int THREADS_PER_BLOCK>
+__global__ void gatherKvPageOffsetsKernel(
+    int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
+    int32_t* output_seq_lengths,     // [num_head_kv, batch_size]
+    int32_t const* kv_page_offsets,  // [batch_size, 2, max_num_pages_per_seq]
+    int32_t const* seq_lengths,      // [batch_size]
+    SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const tokens_per_page,
+    int32_t const max_num_pages_per_seq)
+{
+    // Each CUDA block processes one sequence from the batch for one head.
+    int32_t const head_idx = blockIdx.x;
+    int32_t const batch_idx = blockIdx.y;
+    if (batch_idx >= batch_size)
+    {
+        return;
+    }
+
+    // Shared memory for reduction.
+    __shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
+
+    // Get the range of sparse indices and the sequence length.
+    int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
+    int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
+    int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
+    int32_t const num_sparse_pages = end_offset - start_offset;
+    int32_t const original_seq_len = seq_lengths[batch_idx];
+
+    // Get global sparse index.
+    int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
+
+    // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
+    size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
+    size_t const dst_base_offset = (size_t) head_idx * batch_size * 2 * max_num_pages_per_seq + src_base_offset;
+
+    // Initialize the local max page index and number of valid pages.
+    int32_t local_max_page_index = -1;
+    int32_t local_num_valid_pages = 0;
+
+    // Perform the gather operation.
+    for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
+    {
+        // Get the source idx and offset.
+        int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
+        if (src_idx < 0)
+        {
+            continue;
+        }
+
+        // Update the local max page index.
+        local_max_page_index = max(local_max_page_index, src_idx);
+        local_num_valid_pages++;
+
+        // Get the source and destination offsets.
+        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
+        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
+
+        // Perform the gather operation: read from the sparse location and write to the dense location.
+        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+    }
+
+    // Reduce the local max page indices and number of valid pages.
+    Pair local_pair = {local_max_page_index, local_num_valid_pages};
+    Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
+
+    // Update sequence length for this head and batch.
+    if (threadIdx.x == 0)
+    {
+        int32_t const max_page_index = result.max_val;
+        int32_t const num_valid_pages = result.sum_val;
+        int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
+        size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
+        if (num_valid_pages > 0)
+        {
+            int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
+            int32_t seq_len_remain = original_seq_len % tokens_per_page;
+            if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
+            {
+                seq_len += tokens_per_page - seq_len_remain;
+            }
+            output_seq_lengths[seq_len_offset] = seq_len;
+        }
+        else
+        {
+            output_seq_lengths[seq_len_offset] = 0;
+        }
+    }
+}
+
+// Host-side launcher function
+void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_seq_lengths,
+    int32_t const* kv_page_offsets, int32_t const* seq_lengths, SparseAttentionParams const sparse_params,
+    int32_t const batch_size, int32_t const num_head_kv, int32_t const tokens_per_page,
+    int32_t const max_num_pages_per_seq, cudaStream_t stream)
+{
+    // The grid.
+    dim3 grid(num_head_kv, batch_size, 1);
+    // The block.
+    dim3 block(256, 1, 1);
+    // Shared memory size.
+    size_t smem_size = sizeof(Pair) * 256;
+
+    // Launch the kernel.
+    gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
+        kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
+}
+} // namespace kernels
+} // namespace tensorrt_llm
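To make the sequence-length fix-up above concrete: with tokens_per_page = 32 and original_seq_len = 100, the sequence spans ori_valid_pages = 4 pages and the last page holds seq_len_remain = 4 tokens. If a head keeps num_valid_pages = 2 pages, seq_len starts at 100 - (4 - 2) * 32 = 36; if the kept pages do not include the final, partially filled page (max_page_index != 3), the kernel pads by 32 - 4 = 28 back up to 64, i.e. two full pages. A hedged usage sketch of the launcher; device-buffer allocation and SparseAttentionParams population are elided, and the toy sizes are illustrative:

#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

// Launch the gather for a toy problem; all device pointers are assumed to be
// allocated and populated elsewhere, with the shapes given in the kernel comments.
void runGatherExample(int32_t* d_out_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
    int32_t* d_out_seq_lengths,                    // [num_head_kv, batch_size]
    int32_t const* d_kv_page_offsets,              // [batch_size, 2, max_num_pages_per_seq]
    int32_t const* d_seq_lengths,                  // [batch_size]
    tensorrt_llm::kernels::SparseAttentionParams const& sparse_params, cudaStream_t stream)
{
    int32_t const batch_size = 4;
    int32_t const num_head_kv = 8;       // one gathered page table per KV head
    int32_t const tokens_per_page = 32;  // KV cache page granularity
    int32_t const max_num_pages_per_seq = 128;

    // One thread block per (kv_head, sequence) pair, as in the launcher above.
    tensorrt_llm::kernels::invokeGatherKvPageOffsets(d_out_page_offsets, d_out_seq_lengths, d_kv_page_offsets,
        d_seq_lengths, sparse_params, batch_size, num_head_kv, tokens_per_page, max_num_pages_per_seq, stream);
}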
