Fix bmm_fp8 cublasLt handle usage in autotuned cublas runner #26381 #2808
baonudesifeizhai wants to merge 1 commit into flashinfer-ai:main from
Conversation
📝 Walkthrough
This pull request refactors cuBLASLt handle management by introducing thread-local caching per device in the GPU kernel code.
Code Review
This pull request correctly addresses a bug where a cublasHandle_t was being incorrectly reinterpreted as a cublasLtHandle_t. The fix introduces a thread-local cache to manage cublasLtHandle_t instances, which is a robust and appropriate solution. The implementation is clean and effectively resolves the issue. I have one suggestion to enhance error reporting during resource cleanup.
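For context, a minimal sketch of what such a thread-local, per-device cache can look like is shown here. The ThreadLocalCublasLtHandles name and the destructor body come from the diff below; the get_or_create helper, the map layout, and the get_cublaslt_handle accessor are illustrative assumptions, not the actual FlashInfer implementation.

#include <unordered_map>

#include <cublasLt.h>

// Sketch of a per-thread cache of cublasLt handles, keyed by device ordinal.
// Assumes the caller has already made `device` current (e.g. via cudaSetDevice).
struct ThreadLocalCublasLtHandles {
  std::unordered_map<int, cublasLtHandle_t> handles;

  cublasLtHandle_t get_or_create(int device) {
    auto it = handles.find(device);
    if (it != handles.end()) {
      return it->second;
    }
    cublasLtHandle_t handle = nullptr;
    cublasLtCreate(&handle);  // creates a genuine cublasLtHandle_t, no reinterpret_cast
    handles.emplace(device, handle);
    return handle;
  }

  ~ThreadLocalCublasLtHandles() {
    // Destroy whatever this thread created; the quoted diff below is the actual destructor.
    for (auto& [_, handle] : handles) {
      if (handle != nullptr) {
        (void)cublasLtDestroy(handle);
      }
    }
  }
};

// One cache per thread, so the hot path needs no locking.
inline cublasLtHandle_t get_cublaslt_handle(int device) {
  thread_local ThreadLocalCublasLtHandles cache;
  return cache.get_or_create(device);
}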
for (auto& [_, handle] : handles) {
  if (handle != nullptr) {
    (void)cublasLtDestroy(handle);
  }
}
The destructor for ThreadLocalCublasLtHandles silently ignores the return status of cublasLtDestroy. While it's correct not to throw an exception from a destructor, logging a failure to stderr would provide valuable diagnostic information if resource cleanup fails. This can happen during program shutdown when the CUDA context might no longer be valid, and logging would help debug potential resource leaks.
for (auto& [_, handle] : handles) {
if (handle != nullptr) {
if (cublasStatus_t status = cublasLtDestroy(handle); status != CUBLAS_STATUS_SUCCESS) {
// Cannot throw in a destructor, but logging to stderr is helpful for debugging.
std::cerr << "[FlashInfer] Warning: cublasLtDestroy failed in destructor with status: "
<< cublasGetStatusString(status) << std::endl;
}
}
}
/bot run
yzh119
left a comment
Hi @baonudesifeizhai, thanks for the fix, it looks good.
torch.cuda.current_blas_handle() returns cublasHandle_t, but bmm_fp8 reinterpreted it as cublasLtHandle_t before calling cublasLt APIs
I noticed that cublasLtHandle is also used in PyTorch; do you know if it's exposed in Python (https://github.com/pytorch/pytorch/blob/98849e6ad451547ea09a958757e6bb422996e65f/aten/src/ATen/cuda/CublasHandlePool.cpp#L417)?
At the Python layer we only have torch.cuda.current_blas_handle(), but it loses context: it only gives an integer handle for the current state. We could use at::cuda::getCurrentCUDABlasHandle(), but I want to keep this decoupled from the PyTorch side...
[CANCELING] Pipeline #46412613: canceled
/bot run
[FAILED] Pipeline #46557796: 13/20 passed
📌 Description
🔍 Related Issues
vllm-project/vllm#26381
Root cause:
torch.cuda.current_blas_handle() returns cublasHandle_t, but bmm_fp8 reinterpreted it as cublasLtHandle_t before calling cublasLt APIs. This can fail in autotune/profiling paths on Blackwell.
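To make the mismatch concrete, here is a hedged sketch of the broken and fixed patterns. The function names and arguments are hypothetical; only the handle types and the reinterpretation come from this PR.

#include <cstdint>

#include <cublasLt.h>

// Buggy pattern: torch.cuda.current_blas_handle() hands Python an integer that is
// really a cublasHandle_t. Reinterpreting it as a cublasLtHandle_t and passing it to
// cublasLt APIs is what, per this PR, can fail in autotune/profiling paths on Blackwell.
void bmm_fp8_buggy(uintptr_t torch_blas_handle) {
  auto lt_handle = reinterpret_cast<cublasLtHandle_t>(torch_blas_handle);
  // cublasLtMatmul(lt_handle, ...);  // wrong handle type
}

// Fixed pattern: obtain a real cublasLtHandle_t (in the PR, from a thread-local
// per-device cache) instead of reinterpreting the cuBLAS handle.
void bmm_fp8_fixed() {
  cublasLtHandle_t lt_handle = nullptr;
  cublasLtCreate(&lt_handle);
  // cublasLtMatmul(lt_handle, ...);
  (void)cublasLtDestroy(lt_handle);
}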
On the vLLM side:
before: https://paste.ubuntu.com/p/npS2tkZY2c/
after: https://paste.ubuntu.com/p/c6Ys69PqvR/
and: https://paste.ubuntu.com/p/rPYcXQSwPC/
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes