@@ -747,9 +747,15 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_v_per_head = (mMLAParams.v_head_dim);
+    int dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_v_per_head = (mMLAParams.v_head_dim);
+    if (useSparseMLA())
+    {
+        dim_q_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_k_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_v_per_head = mMLAParams.kv_lora_rank;
+    }
 
     // Total dimension per token across all heads for Q, K, and V components respectively
     int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
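A minimal standalone sketch of the per-head dimension switch above, for sanity-checking the workspace sizing. The concrete values are assumptions in the style of DeepSeek MLA (qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128, kv_lora_rank = 512); they are not taken from this PR.

```cpp
// Hypothetical, self-contained illustration of the Q/K/V per-head dimension
// selection introduced above; the numeric head dimensions are assumed values.
#include <cstdio>

int main()
{
    // Assumed DeepSeek-style MLA dimensions (not from this PR).
    int const qk_nope_head_dim = 128;
    int const qk_rope_head_dim = 64;
    int const v_head_dim = 128;
    int const kv_lora_rank = 512;
    bool const use_sparse_mla = true; // stand-in for useSparseMLA()

    // Dense context MLA: Q/K carry nope + rope dims, V carries v_head_dim.
    int dim_q_per_head = qk_nope_head_dim + qk_rope_head_dim; // 192
    int dim_k_per_head = qk_nope_head_dim + qk_rope_head_dim; // 192
    int dim_v_per_head = v_head_dim;                          // 128
    if (use_sparse_mla)
    {
        // Sparse MLA works in the compressed (latent) space, so Q/K use
        // kv_lora_rank + rope and V shrinks to kv_lora_rank.
        dim_q_per_head = kv_lora_rank + qk_rope_head_dim; // 576
        dim_k_per_head = kv_lora_rank + qk_rope_head_dim; // 576
        dim_v_per_head = kv_lora_rank;                    // 512
    }
    std::printf("per-head dims q/k/v: %d/%d/%d\n", dim_q_per_head, dim_k_per_head, dim_v_per_head);
    return 0;
}
```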
@@ -1110,6 +1116,16 @@ int AttentionOp::mlaGeneration(
                 = reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
         }
 
+        // Set the following parameters if sparseMLA is used.
+        if (useSparseMLA())
+        {
+            tllmRunnerParams.mSparseMla = true;
+            tllmRunnerParams.mSparseMlaTopK = mRuntimeSparseAttentionParams.sparse_mla_topk;
+            tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(
+                mRuntimeSparseAttentionParams.sparse_attn_indices);
+            tllmRunnerParams.kvPtr = mRuntimeSparseAttentionParams.sparse_mla_kv_cache_pool;
+        }
+
         mTllmGenFMHARunner->run(tllmRunnerParams);
         sync_check_cuda_error(stream);
     }
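For orientation only: the fields read from mRuntimeSparseAttentionParams above suggest roughly the following shape for the sparse-attention runtime bundle. The field names come from this diff; the struct name, types, and layout comment are assumptions, not the actual TensorRT-LLM definition.

```cpp
// Illustrative approximation only; not the real runtime sparse-attention struct.
#include <cstdint>

struct SparseAttentionRuntimeParamsSketch
{
    // Number of KV entries each query attends to after top-k selection.
    int32_t sparse_mla_topk = 0;
    // Selected KV page/token indices, assumed to be laid out per query as
    // [num_tokens, sparse_mla_topk]; handed to the kernel as kvPageIdxPtr.
    int32_t const* sparse_attn_indices = nullptr;
    // Base pointer of the compressed (latent) KV cache pool that the indices
    // refer to; handed to the kernel as kvPtr.
    void* sparse_mla_kv_cache_pool = nullptr;
};
```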
@@ -1297,6 +1313,12 @@ int AttentionOp::mlaGeneration(
         fmhaParams.stream = stream;
         fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
 
+        // Sparse attention parameters
+        if (useSparseMLA())
+        {
+            fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
+        }
+
         // Run the fmha kernel
         mDecoderFMHARunner->run(fmhaParams);
     }
@@ -1405,9 +1427,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
-    int const dim_v_per_head = (mMLAParams.v_head_dim);
+    int dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int dim_v_per_head = (mMLAParams.v_head_dim);
+    if (useSparseMLA())
+    {
+        dim_q_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_k_per_head = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+        dim_v_per_head = mMLAParams.kv_lora_rank;
+    }
 
     // Total dimension per token across all heads for Q, K, and V components respectively
     int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
@@ -1721,9 +1749,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
             params.mla_param->host_bmm1_scale
                 = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
+            // The sparse MLA is in the absorption mode for the context phase.
+            params.mla_param->absorption_mode = useSparseMLA();
             if (params.mla_param->latent_cache != nullptr)
             {
-                // compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
                 invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
             }
             if (mFP8ContextMLA)
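Background on the new absorption_mode flag, as a sketch of the standard MLA weight-absorption trick rather than a statement about this kernel's internals: with the K up-projection $W^{UK}$ folded into the query, the context attention score can be computed directly against the compressed latent cache $c^{KV}$ of width kv_lora_rank,

$$
\mathrm{score}_{ij} \;\propto\; \bigl(W^{UK\top} q_i^{\mathrm{nope}}\bigr)^{\top} c_j^{KV} \;+\; q_i^{\mathrm{rope}\,\top} k_j^{\mathrm{rope}},
$$

which is consistent with the per-head Q/K dimension becoming kv_lora_rank + qk_rope_head_dim and the V dimension becoming kv_lora_rank in the workspace sizing above.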
@@ -1841,6 +1870,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
         fmhaParams.softmaxStatsPtr = params.softmax_stats;
 
+        // Sparse attention parameters
+        if (useSparseMLA())
+        {
+            fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
+        }
+
         if (mAttentionChunkSize)
         {
             fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;
@@ -2702,27 +2737,43 @@ int AttentionOp::initialize() noexcept
     fmhaParams.numTokensPerBlock = mTokensPerBlock;
     fmhaParams.headSize = mHeadSize;
     fmhaParams.headSizeV = mHeadSize;
+    fmhaParams.qScaling = mQScaling;
 
     // mFmhaDispatcher is not used for generation MLA, but we still need to modify these values to avoid selecting
     // the wrong kernel, no matter mIsGenerationMLA is true or false
     if (mIsMLAEnabled)
     {
-        // Context MLA always use separate_q_k_v layout
-        fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
-        // Context attention of MLA is different
-        fmhaParams.numKvHeads = mNumHeads;
-        fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
-        // Ideally this should be mMLAParams.v_head_dim, but because we initialize both MLA context(v_head_dim=128)
-        // and gen(v_head_dim=512) runners in a single op, the headSizeV will be set to 512 when we create the gen
-        // attention op and that could fail to create the FmhaDispatcher for context phase.
-        // Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
-        fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
-        fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+        if (useSparseMLA())
+        {
+            fmhaParams.attentionInputLayout = AttentionInputLayout::Q_PAGED_KV;
+            fmhaParams.numKvHeads = 1;
+            fmhaParams.headSize = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
+            fmhaParams.headSizeV = mMLAParams.kv_lora_rank;
+            fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+            // Adjust the qScaling for the absorption mode.
+            fmhaParams.qScaling = mQScaling
+                * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim))
+                / sqrtf((float) (mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim));
+        }
+        else
+        {
+            // Context MLA always use separate_q_k_v layout
+            fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
+            // Context attention of MLA is different
+            fmhaParams.numKvHeads = mNumHeads;
+            fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
+            // Ideally this should be mMLAParams.v_head_dim, but because we initialize both MLA
+            // context(v_head_dim=128) and gen(v_head_dim=512) runners in a single op, the headSizeV will be set to
+            // 512 when we create the gen attention op and that could fail to create the FmhaDispatcher for context
+            // phase. Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
+            fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
+            fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
+        }
     }
-    fmhaParams.qScaling = mQScaling;
     fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
+    fmhaParams.useSparseMLA = useSparseMLA();
 
     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
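A short worked check of the qScaling adjustment above, under the assumption (suggested by the host_bmm1_scale expression earlier in this diff) that the runner applies a softmax scale of $1 / (q_\text{scaling} \cdot \sqrt{\text{headSize}})$. Writing $d_\text{nope}$, $d_\text{rope}$, $d_\text{lora}$ for qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank and $m_Q$ for mQScaling:

$$
\frac{1}{q_{\text{new}}\sqrt{d_{\text{lora}}+d_{\text{rope}}}}
= \frac{1}{\,m_{Q}\dfrac{\sqrt{d_{\text{nope}}+d_{\text{rope}}}}{\sqrt{d_{\text{lora}}+d_{\text{rope}}}}\,\sqrt{d_{\text{lora}}+d_{\text{rope}}}\,}
= \frac{1}{m_{Q}\sqrt{d_{\text{nope}}+d_{\text{rope}}}}.
$$

So even though headSize grows to kv_lora_rank + qk_rope_head_dim in the absorbed layout, the effective softmax scale stays at the original $1 / (m_Q \sqrt{d_\text{nope} + d_\text{rope}})$.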