
Commit cc46992

bugfix: Fix "more than one operator "/" matches these operands" (#1471)
## 📌 Description

Seeing the following error on GB200 at https://github.com/flashinfer-ai/flashinfer/blob/fe442a2df64f46b021f3ad2bc184cd10b09b1d7d/include/flashinfer/attention/mla_hopper.cuh#L829:

```
flashinfer/attention/mla_hopper.cuh(829): error: more than one operator "/" matches these operands:
            built-in operator "arithmetic / arithmetic"
            function "flashinfer::operator/(uint32_t, const flashinfer::uint_fastdiv &)"
              (declared at line 85 of flashinfer/attention/../fastdiv.cuh)
          operand types are: IdType / const flashinfer::uint_fastdiv
      (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) /
                                                                                    ^
          detected during:
            instantiation of "void flashinfer::mla::hopper::BatchMLAPageAttentionHopperKernel<KTraits,Params>(Params) [with KTraits=flashinfer::mla::hopper::HopperKernelTraits<true, 2U, 512U, 64U, 64U, 64U, DTypeQ, DTypeKV, DTypeO, IdType>, Params=flashinfer::MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>]" at line 992
            instantiation of "cudaError_t flashinfer::mla::BatchMLAPageAttentionHopper<MASK_MODE,HEAD_DIM_CKV,HEAD_DIM_KPE,Params>(Params, uint32_t, uint32_t, cudaStream_t) [with MASK_MODE=flashinfer::MaskMode::kCausal, HEAD_DIM_CKV=512U, HEAD_DIM_KPE=64U, Params=flashinfer::MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>]" at line 65 of fbcode/deeplearning/flashinfer/build/aot/generated/batch_mla_attention_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False_sm90/batch_mla_sm90_run.cu
```

Adding `static_cast<uint32_t>()` around `packed_qo_start` resolves the issue for me.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
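For context on why nvcc reports the ambiguity, below is a minimal host-compilable sketch. It is not FlashInfer's actual `fastdiv.cuh`: the implicit conversion operator, field name, and trivial divide body are assumptions chosen to mirror the shape of the error (the real overload divides via a precomputed magic-number multiply).

```cpp
#include <cstdint>

// Hypothetical stand-in for flashinfer::uint_fastdiv. Assumption: the real
// struct exposes an implicit conversion back to the raw divisor, which is
// what makes the built-in "arithmetic / arithmetic" operator a candidate.
struct uint_fastdiv {
  uint32_t d;
  operator uint32_t() const { return d; }
};

// Stand-in for the overload declared in fastdiv.cuh; the real version
// uses a precomputed magic multiply instead of hardware '/'.
inline uint32_t operator/(uint32_t n, const uint_fastdiv& g) { return n / g.d; }

int main() {
  int32_t packed_qo_start = 128;  // IdType is a signed 32-bit index here
  uint_fastdiv num_heads{8};

  // error: more than one operator "/" matches these operands
  //  - built-in int / unsigned: exact match on the left, user-defined
  //    conversion (uint_fastdiv -> uint32_t) on the right
  //  - operator/(uint32_t, const uint_fastdiv&): integral conversion
  //    on the left, exact match on the right
  // Neither candidate is better on both arguments, hence the ambiguity.
  // uint32_t bad = packed_qo_start / num_heads;

  // The commit's fix: make the left operand uint32_t so the fastdiv
  // overload is an exact match on both arguments.
  uint32_t ok = static_cast<uint32_t>(packed_qo_start) / num_heads;
  return static_cast<int>(ok);
}
```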
1 parent 755beff · commit cc46992

File tree

1 file changed: +2 −1


include/flashinfer/attention/mla_hopper.cuh

Lines changed: 2 additions & 1 deletion
```diff
@@ -826,7 +826,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
             1 - (kv_start / CTA_TILE_KV);
 
     int mask_tile_idx =
-        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) /
+        (CAUSAL ? min(kv_end, kv_len - q_len + static_cast<uint32_t>(packed_qo_start) / num_heads)
+                : kv_end) /
             CTA_TILE_KV -
         (kv_start / CTA_TILE_KV);
 
```
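As a design note, the cast lands on the left operand rather than on `num_heads`, which keeps the division on the `uint_fastdiv` fast path. A hedged sketch of the two ways the call could have been disambiguated, reusing the hypothetical stand-ins from the snippet above:

```cpp
// The commit's fix: the left operand becomes uint32_t, so
// operator/(uint32_t, const uint_fastdiv&) is the unique best match
// and the precomputed fast-division path is kept.
uint32_t kept = static_cast<uint32_t>(packed_qo_start) / num_heads;

// An alternative that would also compile, but (assuming the implicit
// conversion operator exists) converts num_heads back to a plain
// uint32_t and falls back to ordinary hardware division:
uint32_t lost = packed_qo_start / static_cast<uint32_t>(num_heads);
```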
