Skip to content

Commit 5ee2ed0

Browse files
committed
fix params issue
Signed-off-by: yuhangh <[email protected]>
1 parent d6d558b commit 5ee2ed0

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,11 @@ class AttentionOp
494494
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
495495
mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
496496
mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
497-
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mMLAParams.data(), mCpSize, mCpRank,
498-
mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
499-
mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
500-
mUseKVCache, mSkipAttn, mFuseFp4Quant, mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores,
501-
mAttentionChunkSize.value_or(-1));
497+
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
498+
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
499+
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
500+
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
501+
mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
502502
};
503503

504504
private:

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,8 @@ void attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optiona
684684
op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
685685
op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree
686686

687+
op->mUseSparseAttention = false;
688+
op->mUseTllmGenSparseAttention = false;
687689
if ((sparse_kv_indices.has_value() && sparse_kv_indices.value().numel() > 0)
688690
|| (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0))
689691
{

0 commit comments

Comments (0)