Commit 60cf6e4
minor: some fix and cleanup for trtllm-gen mha (#1302)
## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 63a3074 · commit 60cf6e4

File tree

2 files changed: +3 -5 lines changed

flashinfer/decode.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -2065,7 +2065,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
         block_tables: page_table of kv cache, [batch_size, num_pages]
         seq_lens: query_len
         max_seq_len: max sequence length for kv_cache
-        scale: model-specific scale of qk, default is 1.0
         out: output tensor, if not provided, will be allocated internally
         bmm1_scale: fused scale for mla bmm1 input.
         bmm2_scale: fused scale for mla bmm2 input.
```

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Lines changed: 3 additions & 4 deletions
```diff
@@ -22,6 +22,8 @@
 #include <cstdio>
 #include <cstring>
 
+#include "flashinfer/exception.h"
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 // The attention mask types.
@@ -288,10 +290,7 @@ struct TllmGenFmhaRunnerParams {
       mMaskType = TrtllmGenAttentionMaskType::Custom;
       break;
     default:
-      // TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
-      //            static_cast<int>(maskType));
-      printf("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
-             static_cast<int>(maskType));
+      FLASHINFER_ERROR("Invalid attention mask type");
   }
   return *this;
 }
```
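The substantive change in this hunk replaces a `printf`-based warning with a hard error when an attention mask type cannot be mapped. Below is a minimal, self-contained C++ sketch of that pattern; the enum values and the `THROW_ERROR` macro are simplified stand-ins for the types in `fmhaRunnerParams.h` and for `FLASHINFER_ERROR` from `flashinfer/exception.h` (assumed here to throw an exception), not the actual FlashInfer definitions.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for FLASHINFER_ERROR from "flashinfer/exception.h".
#define THROW_ERROR(msg) throw std::runtime_error(std::string(msg))

// Simplified stand-ins for the mask-type enums used in fmhaRunnerParams.h.
enum class ContextAttentionMaskType { kPadding, kCausal, kCustom };
enum class TrtllmGenAttentionMaskType { Dense, Causal, Custom };

TrtllmGenAttentionMaskType mapMaskType(ContextAttentionMaskType maskType) {
  switch (maskType) {
    case ContextAttentionMaskType::kPadding:
      return TrtllmGenAttentionMaskType::Dense;
    case ContextAttentionMaskType::kCausal:
      return TrtllmGenAttentionMaskType::Causal;
    case ContextAttentionMaskType::kCustom:
      return TrtllmGenAttentionMaskType::Custom;
    default:
      // Before this commit the unmapped case only printf'd a message and
      // continued; throwing makes the invalid input a hard error.
      THROW_ERROR("Invalid attention mask type");
  }
}
```

Throwing instead of printing means a caller cannot silently continue with an unset mask type; the failure surfaces immediately at the call site.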
