Commit 51652b9
feat: add PositionEmbeddingType=0 to xqa support (NVIDIA#4934)
Signed-off-by: Jiying Dong <[email protected]>
1 parent bfa877a commit 51652b9
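
Note on the title: PositionEmbeddingType=0 corresponds to kLEARNED_ABSOLUTE, the enumerator this commit adds to the XQA support check in kernelUtils.cpp (second file below). A minimal sketch of that correspondence, with the other enumerators' values elided because they are not shown in this diff:

// Sketch only: kLEARNED_ABSOLUTE = 0 is what "PositionEmbeddingType=0" in the
// commit title refers to. The remaining enumerators are abbreviated here and
// follow the TensorRT-LLM headers; their exact values are not asserted.
enum class PositionEmbeddingType
{
    kLEARNED_ABSOLUTE = 0, // learned absolute embeddings; no RoPE involved
    kROPE_GPTJ,
    kROPE_GPT_NEOX,
    kLONG_ROPE,
    kROPE_M,
    // ... other types omitted
};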

File tree

2 files changed (+5 -5 lines)


cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp

Lines changed: 2 additions & 3 deletions
@@ -68,15 +68,14 @@ CubinObj CompileEngine::compile() const
         case PositionEmbeddingType::kROPE_GPTJ: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_GPTJ; break;
         case PositionEmbeddingType::kROPE_GPT_NEOX:
         case PositionEmbeddingType::kLONG_ROPE: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NEOX; break;
-        // For kROPE_M, set ropeStyle to TLLM_XQA_JIT_ROPE_NONE to let XQA kernel not apply RoPE.
-        // At runtime, a separate kernel (see invokeQKVPreprocessing) will be launched to apply RoPE.
-        case PositionEmbeddingType::kROPE_M: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NONE; break;
         default: TLLM_THROW("TllmXqaJit: Bad RoPE type");
         }
     }
     else
     {
         // Make it explicit that Ampere-style kernel doesn't apply RoPE in the kernel.
+        // For kROPE_M, set ropeStyle to TLLM_XQA_JIT_ROPE_NONE to let XQA kernel not apply RoPE.
+        // At runtime, a separate kernel (see invokeQKVPreprocessing) will be launched to apply RoPE.
         ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NONE;
     }
     if (applyRoPEInXqaKernel)
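
To make the control flow easier to follow, here is a condensed sketch of the ropeStyle selection after this change. The applyRoPEInXqaKernel flag appears in the diff's context lines; wrapping the logic in a standalone helper (selectRopeStyle) is purely illustrative, not a function in the repository.

// Hypothetical helper condensing the selection above; the real logic lives
// inline in CompileEngine::compile().
tllmXqaJitRopeStyle selectRopeStyle(PositionEmbeddingType type, bool applyRoPEInXqaKernel)
{
    if (applyRoPEInXqaKernel)
    {
        switch (type)
        {
        // Styles the JIT kernel can fuse directly:
        case PositionEmbeddingType::kROPE_GPTJ: return tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_GPTJ;
        case PositionEmbeddingType::kROPE_GPT_NEOX:
        case PositionEmbeddingType::kLONG_ROPE: return tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NEOX;
        // kROPE_M was removed from this switch by this commit, so requesting
        // in-kernel RoPE fusion for it now hits the throw below.
        default: TLLM_THROW("TllmXqaJit: Bad RoPE type");
        }
    }
    // kROPE_M reaches this path: the XQA kernel applies no RoPE, and a separate
    // runtime kernel (see invokeQKVPreprocessing) applies it instead. For
    // kLEARNED_ABSOLUTE there is no RoPE to apply in the first place.
    return tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NONE;
}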

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp

Lines changed: 3 additions & 2 deletions
@@ -62,8 +62,9 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
         // TODO: remove this when the kernel bug for num_kv_heads <= 128 gets fixed.
         return false;
     }
-    if (!contains({PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX,
-            PositionEmbeddingType::kROPE_M, PositionEmbeddingType::kLONG_ROPE},
+    if (!contains(
+            {PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX, PositionEmbeddingType::kROPE_M,
+                PositionEmbeddingType::kLONG_ROPE, PositionEmbeddingType::kLEARNED_ABSOLUTE},
             xqaParams.position_embedding_type))
     {
         return false;
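
The support check above relies on a small contains() helper defined elsewhere in the repository. A minimal, self-contained sketch with the same call shape (the actual signature may differ):

#include <algorithm>
#include <initializer_list>

// Illustrative stand-in for the contains() helper used in supportConfigCommon().
template <typename T>
bool contains(std::initializer_list<T> list, T const& value)
{
    return std::find(list.begin(), list.end(), value) != list.end();
}

With kLEARNED_ABSOLUTE added to the list, supportConfigCommon() no longer rejects configurations whose position_embedding_type is 0, which is the entire effect of this file's change.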
