Commit 75ad360

longlee0622 authored and ZhanruiSunCh committed
[None][fix] Enable FP8 ContextMLA on GB300 (NVIDIA#8080)
Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com>
1 parent ea640a1 commit 75ad360

File tree

1 file changed (+3 -2 lines)

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 3 additions & 2 deletions
@@ -647,8 +647,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
         static_cast<int>(layer_num)};
 
-    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
-        || tensorrt_llm::common::getSMVersion() == 120)
+    op->mFP8ContextMLA
+        = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
+            || tensorrt_llm::common::getSMVersion() == 103 || tensorrt_llm::common::getSMVersion() == 120)
         && op->mKVCacheQuantMode.hasFp8KvCache();
     op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
     op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();

0 commit comments
