
Commit f0dc746

lfr-0531, PerkzZheng, and Tracin authored
[TRTLLM-8541][feat] Add trtllm-gen sparse MLA kernels to support per-Tensor FP8 KV Cache (#8692)
Signed-off-by: Perkz Zheng <[email protected]>
Signed-off-by: Tracin <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Tracin <[email protected]>
1 parent da2dca5 commit f0dc746


1,353 files changed: +3,331 additions, −2,971 deletions


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 70 additions & 19 deletions
@@ -747,9 +747,15 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_v_per_head = (mMLAParams.v_head_dim);
+    int dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_v_per_head = (mMLAParams.v_head_dim);
+    if (useSparseMLA())
+    {
+        dim_q_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_k_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_v_per_head = mMLAParams.kv_lora_rank;
+    }

     // Total dimension per token across all heads for Q, K, and V components respectively
     int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;

@@ -1110,6 +1116,16 @@ int AttentionOp::mlaGeneration(
             = reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
         }

+        // Set the following parameters if sparseMLA is used.
+        if (useSparseMLA())
+        {
+            tllmRunnerParams.mSparseMla = true;
+            tllmRunnerParams.mSparseMlaTopK = mRuntimeSparseAttentionParams.sparse_mla_topk;
+            tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(
+                mRuntimeSparseAttentionParams.sparse_attn_indices);
+            tllmRunnerParams.kvPtr = mRuntimeSparseAttentionParams.sparse_mla_kv_cache_pool;
+        }
+
         mTllmGenFMHARunner->run(tllmRunnerParams);
         sync_check_cuda_error(stream);
     }

@@ -1297,6 +1313,12 @@ int AttentionOp::mlaGeneration(
         fmhaParams.stream = stream;
         fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;

+        // Sparse attention parameters
+        if (useSparseMLA())
+        {
+            fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
+        }
+
         // Run the fmha kernel
         mDecoderFMHARunner->run(fmhaParams);
     }

@@ -1405,9 +1427,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_v_per_head = (mMLAParams.v_head_dim);
+    int dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_v_per_head = (mMLAParams.v_head_dim);
+    if (useSparseMLA())
+    {
+        dim_q_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_k_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_v_per_head = mMLAParams.kv_lora_rank;
+    }

     // Total dimension per token across all heads for Q, K, and V components respectively
     int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;

@@ -1721,9 +1749,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
         params.mla_param->host_bmm1_scale
             = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
+        // The sparse MLA is in the absorption mode for the context phase.
+        params.mla_param->absorption_mode = useSparseMLA();
         if (params.mla_param->latent_cache != nullptr)
         {
-            // compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
             invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
         }
         if (mFP8ContextMLA)

@@ -1841,6 +1870,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
         fmhaParams.softmaxStatsPtr = params.softmax_stats;

+        // Sparse attention parameters
+        if (useSparseMLA())
+        {
+            fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
+        }
+
         if (mAttentionChunkSize)
         {
             fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;

@@ -2702,27 +2737,43 @@ int AttentionOp::initialize() noexcept
     fmhaParams.numTokensPerBlock = mTokensPerBlock;
     fmhaParams.headSize = mHeadSize;
     fmhaParams.headSizeV = mHeadSize;
+    fmhaParams.qScaling = mQScaling;

     // mFmhaDispatcher is not used for generation MLA, but we still need to modify these values to avoid selecting
     // the wrong kernel, no matter mIsGenerationMLA is true or false
     if (mIsMLAEnabled)
     {
-        // Context MLA always use separate_q_k_v layout
-        fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
-        // Context attention of MLA is different
-        fmhaParams.numKvHeads = mNumHeads;
-        fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
-        // Ideally this should be mMLAParams.v_head_dim, but because we initialize both MLA context(v_head_dim=128)
-        // and gen(v_head_dim=512) runners in a single op, the headSizeV will be set to 512 when we create the gen
-        // attention op and that could fail to create the FmhaDispatcher for context phase.
-        // Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
-        fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
-        fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+        if (useSparseMLA())
+        {
+            fmhaParams.attentionInputLayout = AttentionInputLayout::Q_PAGED_KV;
+            fmhaParams.numKvHeads = 1;
+            fmhaParams.headSize = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+            fmhaParams.headSizeV = mMLAParams.kv_lora_rank;
+            fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+            // Adjust the qScaling for the absorption mode.
+            fmhaParams.qScaling = mQScaling
+                * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim))
+                / sqrtf((float) (mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim));
+        }
+        else
+        {
+            // Context MLA always use separate_q_k_v layout
+            fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
+            // Context attention of MLA is different
+            fmhaParams.numKvHeads = mNumHeads;
+            fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
+            // Ideally this should be mMLAParams.v_head_dim, but because we initialize both MLA
+            // context(v_head_dim=128) and gen(v_head_dim=512) runners in a single op, the headSizeV will be set to
+            // 512 when we create the gen attention op and that could fail to create the FmhaDispatcher for context
+            // phase. Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
+            fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
+            fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+        }
     }
-    fmhaParams.qScaling = mQScaling;
     fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
+    fmhaParams.useSparseMLA = useSparseMLA();

     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
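
To make the arithmetic in the sparse-MLA branches above concrete, here is a standalone C++ sketch (not part of the commit) that reproduces the head-dimension and q-scaling computations with assumed DeepSeek-V3-style values (qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512, v_head_dim=128); the MlaDims struct and all constants are illustrative assumptions, not taken from this diff.

#include <cmath>
#include <cstdio>

// Illustrative stand-ins for the MLA dimensions consumed by AttentionOp; the
// real values come from the model config, these are assumed DeepSeek-V3-like numbers.
struct MlaDims
{
    int qk_nope_head_dim = 128;
    int qk_rope_head_dim = 64;
    int kv_lora_rank = 512;
    int v_head_dim = 128;
};

int main()
{
    MlaDims const mla;
    float const qScaling = 1.0f; // assumed base scaling factor

    // Dense (non-sparse) MLA per-head dims, matching the original code path.
    int const denseDimQ = mla.qk_rope_head_dim + mla.qk_nope_head_dim; // 192
    int const denseDimV = mla.v_head_dim;                              // 128

    // Sparse MLA runs in absorption mode: Q/K live in the compressed
    // (kv_lora_rank + rope) space and V is the latent vector itself.
    int const sparseDimQ = mla.kv_lora_rank + mla.qk_rope_head_dim; // 576
    int const sparseDimV = mla.kv_lora_rank;                        // 512

    // Mirrors the adjustment in AttentionOp::initialize(): multiplying by
    // sqrt(192)/sqrt(576) compensates for the larger absorbed head size
    // (assuming the kernel folds qScaling together with 1/sqrt(headSize)).
    float const adjustedQScaling = qScaling
        * std::sqrt(static_cast<float>(mla.qk_nope_head_dim + mla.qk_rope_head_dim))
        / std::sqrt(static_cast<float>(mla.kv_lora_rank + mla.qk_rope_head_dim));

    std::printf("dense:  dimQ=%d dimV=%d\n", denseDimQ, denseDimV);
    std::printf("sparse: dimQ=%d dimV=%d adjustedQScaling=%.4f\n", sparseDimQ, sparseDimV, adjustedQScaling);
    return 0;
}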

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 6 additions & 1 deletion
@@ -365,6 +365,11 @@ class AttentionOp
         return mUseTllmGenSparseAttention && useSparseAttention();
     }

+    [[nodiscard]] bool useSparseMLA() const
+    {
+        return mUseSparseAttention && mUseTllmGen && mIsMLAEnabled;
+    }
+
     [[nodiscard]] int smVersion() const
     {
         return mSM;

@@ -498,7 +503,7 @@ class AttentionOp
            mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
            mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
            mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
-           mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+           mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@
 #include <stdint.h>

 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

 namespace tensorrt_llm
 {

@@ -140,6 +141,8 @@ struct MHARunnerFixedParams
     int sageBlockSizeK = 0;
     // v tensor quant block size in sage attention
     int sageBlockSizeV = 0;
+    // Use sparse MLA ?
+    bool useSparseMLA = false;

     // Convert to string for debug.
     std::string convertToStrOutput()

@@ -307,6 +310,8 @@ struct MHARunnerParams
     int qMaxNBlock;
     int kMaxNBlock;
     int vMaxNBlock;
+    // sparse attention parameters
+    SparseAttentionParams sparse_params;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
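
As a rough mental model for the two fields added above, the sketch below (mock types only, not the real TensorRT-LLM structs) mirrors the split the diff follows: a build-time useSparseMLA flag in the fixed params that drives kernel selection, and per-launch sparse data (top-k, index pointer, KV pool) in the runner params. Field names follow the diff; the types and example values are assumptions.

#include <cstdint>

// Mock mirror of SparseAttentionParams; pointer and integer types are assumed.
struct SparseAttentionParamsSketch
{
    int sparse_mla_topk = 0;                      // top-k KV entries kept per query
    int32_t const* sparse_attn_indices = nullptr; // selected KV page indices
    void* sparse_mla_kv_cache_pool = nullptr;     // latent (compressed) KV cache pool
};

struct FixedParamsSketch // set once, when the dispatcher / cubins are created
{
    bool useSparseMLA = false;
};

struct RunnerParamsSketch // refilled on every enqueue
{
    SparseAttentionParamsSketch sparse_params;
};

int main()
{
    FixedParamsSketch fixed;
    fixed.useSparseMLA = true; // decides which kernels are even considered

    RunnerParamsSketch run;
    run.sparse_params.sparse_mla_topk = 256; // arbitrary example value

    return (fixed.useSparseMLA && run.sparse_params.sparse_mla_topk > 0) ? 0 : 1;
}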

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 18 additions & 0 deletions
@@ -117,6 +117,13 @@ bool FmhaDispatcher::isSupported()
         // the kernel is supported.
         tllmRunnerParams.mChunkedAttentionSize = INT_MAX;
         tllmRunnerParams.mAttentionWindowSize = INT_MAX;
+        // Set the kernel type and mask type if sparseMLA is used.
+        if (mFixedParams.useSparseMLA)
+        {
+            tllmRunnerParams.mSparseMla = true;
+            tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
+            tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        }

         foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
     }

@@ -217,6 +224,17 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
         // For mla chunked prefill
         tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmaxStatsPtr);
         tllmRunnerParams.stream = runnerParams.stream;
+        // Set the sparse attention parameters if sparseMLA is used.
+        if (mFixedParams.useSparseMLA)
+        {
+            tllmRunnerParams.mSparseMla = true;
+            tllmRunnerParams.mSparseMlaTopK = runnerParams.sparse_params.sparse_mla_topk;
+            tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
+            tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+            tllmRunnerParams.kvPageIdxPtr
+                = reinterpret_cast<int const*>(runnerParams.sparse_params.sparse_attn_indices);
+            tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_mla_kv_cache_pool;
+        }

         mTllmGenFMHARunner->run(tllmRunnerParams);
     }
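
Both hunks above force the generation kernel type and a dense mask whenever sparse MLA is enabled; the run() path additionally wires the per-launch top-k, KV page indices, and KV pool. The minimal sketch below (mock enums and struct, not the real trtllm-gen runner API) captures that shared rule; every name suffixed with Sketch and the example top-k value are assumptions.

enum class KernelTypeSketch { Context, Generation };
enum class MaskTypeSketch { Causal, Dense };

// Mock subset of the trtllm-gen runner parameters touched by the diff.
struct TllmGenRunnerParamsSketch
{
    bool mSparseMla = false;
    int mSparseMlaTopK = 0;
    KernelTypeSketch mKernelType = KernelTypeSketch::Context;
    MaskTypeSketch mMaskType = MaskTypeSketch::Causal;
    int const* kvPageIdxPtr = nullptr;
    void const* kvPtr = nullptr;
};

// Sparse MLA always targets the generation kernel with a dense mask; the
// per-launch fields come from the sparse attention params.
void applySparseMlaOverrides(TllmGenRunnerParamsSketch& p, int topK, int const* pageIndices, void const* kvPool)
{
    p.mSparseMla = true;
    p.mKernelType = KernelTypeSketch::Generation;
    p.mMaskType = MaskTypeSketch::Dense;
    p.mSparseMlaTopK = topK;
    p.kvPageIdxPtr = pageIndices;
    p.kvPtr = kvPool;
}

int main()
{
    TllmGenRunnerParamsSketch params;
    applySparseMlaOverrides(params, /*topK=*/256, /*pageIndices=*/nullptr, /*kvPool=*/nullptr);
    return params.mSparseMla ? 0 : 1;
}

Applying the same overrides in both isSupported() and run() presumably keeps the kernel that passed the support check identical to the one that is actually launched.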
