Switch to all-in-one flash-attn v2/v3 library #199
Conversation
|
@sempervictus This PR switches to a custom flash attention library supporting both v2 and v3. Are you able to test this? It uses this library: https://github.com/guoqingbao/flashattn.rs |
|
Applied the codex-suggested patch in this PR. @sempervictus |
|
Build fails:
131.5 Compiling neli v0.7.3
131.6 Compiling dirs v6.0.0
131.6 warning: [email protected]: Cloning cutlass from https://github.com/NVIDIA/cutlass.git
131.6 error: failed to run custom build command for `flashattn-rs v0.1.0 (https://github.com/guoqingbao/flashattn.rs.git?rev=2f86e74#2f86e743)`
131.6
131.6 Caused by:
131.6 process didn't exit successfully: `/vllm.rs/target/release/build/flashattn-rs-3ee655c82534766b/build-script-build` (exit status: 1)
131.6 --- stdout
131.6 cargo:warning=Cloning cutlass from https://github.com/NVIDIA/cutlass.git
131.6
131.6 --- stderr
131.6 Error: Failed to clone cutlass repository
131.6
131.6 Caused by:
131.6 No such file or directory (os error 2)
131.6 warning: build failed, waiting for other jobs to finish... |
Updated the revision of the attention-rs dependency.
|
Network issue? It failed while cloning the repo. You may want to update the PR, since I found the root cause of the crash (introduced by the agents 😓). |
They passed the wrong template args (everything mapped to sm90) to the launch function; that's why it crashed with a device-side assertion (guoqingbao/attention.rs@7a540d8). It might not be related to flash attention. |
|
If you're not using this PR, simply changing the rev id for attention-rs in vllm.rs's Cargo.toml to f525d5f may also work, since that keeps the existing flash attention v2 and only applies the agent fix. |
|
Just pulled the last commit and got the same crash; updating the rev id. |
Or you don't have git installed in the docker environment? |
|
we should STILL get FAv3 online for sure - will need it for fp8 kvcache and fp4 |
Sure |
that would be weird, it's the cudnn-devel container, but added it just in case and re-running |
|
I'm torturing codex to fix this on mobile, not sure if we can fix in one shot. |
Ha, it does look like codex used a weird way to clone that repo. attn.rs cloned submodules fine but for some reason the build container now seems to work w/ |
|
Do you want to amend this PR to either fix how it fetches or just add |
It used another way to clone the repo (only cloning the include folder via a sparse clone, which saves time). The existing attention.rs uses candle flash attention, which still clones cutlass as a submodule (the entire repo), but they recently switched to a sparse clone as well (relying on git). |
Adding git is a good option, since the sparse clone takes much less time than cloning the entire cutlass repo. |
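For readers following the build step, here is a minimal sketch of how a build.rs could sparse-clone only cutlass's include folder, assuming a git binary is present in the build container; the paths, branch, and flags below are illustrative and not the actual flashattn.rs build script:

```rust
// build.rs sketch (illustrative only, not the real flashattn.rs build script).
// Sparse-clones just cutlass/include instead of the full repository, which is
// why a `git` binary must exist in the build environment.
use std::path::Path;
use std::process::Command;

fn git(dir: &Path, args: &[&str]) {
    let status = Command::new("git")
        .args(args)
        .current_dir(dir)
        .status()
        .expect("failed to spawn git; is git installed in the container?");
    assert!(status.success(), "git {:?} failed", args);
}

fn main() {
    let out = std::env::var("OUT_DIR").unwrap();
    let dst = Path::new(&out).join("cutlass");
    if !dst.join("include").exists() {
        std::fs::create_dir_all(&dst).unwrap();
        git(&dst, &["init"]);
        git(&dst, &["remote", "add", "origin", "https://github.com/NVIDIA/cutlass.git"]);
        // Only materialize the headers needed for the CUDA build.
        git(&dst, &["sparse-checkout", "set", "include"]);
        git(&dst, &["pull", "--depth", "1", "origin", "main"]);
    }
    println!("cargo:rerun-if-changed=build.rs");
}
```

The "No such file or directory (os error 2)" in the build log above is what `Command::new("git")` returns when the binary is missing, which is consistent with the container simply lacking git.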
|
The original flash attention v3 takes a very long time to build (over one hour). I have discarded some features and kernels to make it build faster (~10 min), but we lose some performance. I'm thinking we can bring the attention split feature back (crucial for prefill) and still keep compile time under 15 minutes (acceptable to most people). |
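For context on what the split feature buys: split-KV attention divides the KV sequence across extra thread blocks and then combines the partial softmax results, which matters when batch × heads alone cannot fill the GPU. A rough Rust sketch of the kind of heuristic such a kernel uses to pick the split count; every name and threshold here is made up for illustration and is not the library's actual logic:

```rust
// Illustrative split-KV ("attention split") heuristic, loosely modeled on the
// idea behind flash-attention's split selection; names and thresholds are
// invented for this sketch.
fn efficiency(blocks: usize, splits: usize, num_sms: usize) -> f32 {
    // Fraction of the last wave of thread blocks that does useful work.
    let waves = (blocks * splits) as f32 / num_sms as f32;
    waves / waves.ceil()
}

fn pick_num_splits(batch_nheads_mblocks: usize, num_sms: usize, max_splits: usize) -> usize {
    // Enough independent blocks already: splitting only adds combine overhead.
    if batch_nheads_mblocks as f32 >= 0.8 * num_sms as f32 {
        return 1;
    }
    let best = (1..=max_splits)
        .map(|s| efficiency(batch_nheads_mblocks, s, num_sms))
        .fold(0.0_f32, f32::max);
    // Smallest split count whose occupancy is close to the best one found.
    (1..=max_splits)
        .find(|&s| efficiency(batch_nheads_mblocks, s, num_sms) >= 0.85 * best)
        .unwrap_or(1)
}

fn main() {
    // E.g. a single request with few query blocks on a 132-SM GPU:
    // splitting the KV dimension lets more SMs participate.
    println!("num_splits = {}", pick_num_splits(8, 132, 16));
}
```

The compile-time trade-off discussed above comes from instantiating the extra split kernels, which is why they were dropped to get the build under ~10 minutes.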
|
builds with Still cant use
but that's a later thing :-) |
Haven't done that part yet, but is the fp8 model working on sm120? It might be slower for prefill because we disabled the attention split feature (will bring it back soon). You may also try the existing flash attention v2 by changing the rev id; it can reach 8k tokens/s prefill speed on Hopper (30B fp8). |
|
Closer:
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
vllm-rs-svc0 | 2026-01-22T15:51:59.606459Z INFO runner: Runner at rank 0 created (PD config: None)!
Graph capturing: 6%|██▍ | 1/15 [00:00<?, ?it/s]CUDA error (kernels/flash_fwd_launch_template.h:192): invalid argument
vllm-rs-svc0 | CUDA error (kernels/flash_fwd_launch_template.h:192): invalid argument
vllm-rs-svc0 | CUDA error (kernels/flash_fwd_launch_template.h:192): invalid argument
vllm-rs-svc0 | CUDA error (kernels/flash_fwd_launch_template.h:192): invalid argument
vllm-rs-svc0 | 2026-01-22T15:52:44.338316Z INFO vllm_rs::utils::command: Timeout waiting for acknowledgment from subprocess: Os { code: 104, kind: ConnectionReset, message: "Connection reset by peer" }
vllm-rs-svc0 | 2026-01-22T15:52:46.200884Z INFO vllm_rs::utils::command: Timeout waiting for acknowledgment from subprocess: Os { code: 104, kind: ConnectionReset, message: "Connection reset by peer" }
vllm-rs-svc0 | 2026-01-22T15:52:46.200905Z WARN vllm_rs::utils::heartbeat: Os { code: 32, kind: BrokenPipe, message: "Broken pipe" }
vllm-rs-svc0 | 2026-01-22T15:52:46.200912Z INFO vllm_rs::utils::heartbeat: Parent process disconnected, exiting...
vllm-rs-svc0 exited with code 0 |
Try disabling graph capture, or use the rev id (existing flash attention v2). |
That's a flash attention issue: it tries to set the shared memory size for the kernel, but that may not be necessary on sm120. Will fix it soon. |
|
Fixed in the last commit @sempervictus |
|
Thanks for the patience @sempervictus. Will investigate further tomorrow morning if it's still not working. |
|
Thank you - the "morning" thing in all of this is hilarious (I'm in the EST TZ). Appreciate all the late nights/long days, sir
I have no other changes; it uses the old fa2 (candle-flash-attn), and flash attention v2 works with sm120+ as the author said. The only change in this commit is disabling cutlass gemm, because I found it made the program hang on SM120 (RTX 5090).
No need; this PR (all-in-one flash-attn) may not be suitable for sm120+, since flash attention v3 uses sm90+ intrinsics, unless we force it to use the fa2 logic on sm120, which hasn't been tested. |
|
so really what you're saying is we need to glue FAv4 in there! 😁 |
That should not change the kvcache layout. Have you tried the latest main on fp8 MoE models? Disabling the fp8 gemm kernel on sm120 seems to work. I haven't tested MoE models on the RTX 5090 because it only has 32GB VRAM.
Or we wait for fav3's official support on sm120+. |
|
The FP8 Q3Coder is too large - need FP4 for that one or 2-4 more of these cards :-).
vllm-rs-svc0 | 2026-01-23T06:34:01.461880Z WARN vllm_rs::core::engine: [Stream] New request [Seq_id 0, 11 tokens] received! (session_id: None)
vllm-rs-svc0 |
vllm-rs-svc0 | /root/.cargo/git/checkouts/candle-629ca89aaea24b43/f430958/candle-flash-attn/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp:422: void cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, std::enable_if<std::is_base_of_v, void>::type>::operator()(const cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, std::enable_if<std::is_base_of_v, void>::type>::Params &, char *) [with ProblemShape_ = cutlass::gemm::GroupProblemShape<cute::tuple<signed int, signed int, signed int>>; CollectiveMainloop_ = cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling<2, 2, cute::tuple<cute::C<1>, cute::C<1>, cute::C<1>>, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeBlockwiseScalingSm120<2>>, cute::tuple<cute::C<128>, cute::C<128>, cute::C<128>>, cutlass::float_e4m3_t, cute::tuple<cute::tuple<signed long, cute::C<1>, cute::C<0>> *, cute::Layout<cute::tuple<cute::tuple<cute::C<1>, signed int>, cute::tuple<cute::C<128>, signed int>, signed int>, cute::tuple<cute::tuple<cute::C<0>, cute::C<1>>, cute::tuple<cute::C<0>, signed int>, signed int>> *>, cutlass::float_e4m3_t, cute::tuple<cute::tuple<signed long, cute::C<1>, cute::C<0>> *, cute::Layout<cute::tuple<cute::tuple<cute::C<128>, signed int>, cute::tuple<cute::C<128>, signed int>, signed int>, cute::tuple<cute::tuple<cute::C<0>, signed int>, cute::tuple<cute::C<0>, cute::C<1>>, signed int>> *>, cute::TiledMMA<cute::MMA_Atom<cute::SM120_16x8x32_TN<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float>>, cute::Layout<cute::tuple<cute::C<4>, cute::C<2>, cute::C<1>>, cute::tuple<cute::C<1>, cute::C<4>, cute::C<0>>>, cute::tuple<cute::C<128>, cute::C<32>, cute::C<32>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<8>, cute::Layout<cute::tuple<cute::C<8>, cute::C<128>>, cute::tuple<cute::C<128>, cute::C<1>>>>, cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, unsigned char>, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<8>, cute::Layout<cute::tuple<cute::C<8>, cute::C<128>>, cute::tuple<cute::C<128>, cute::C<1>>>>, cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, unsigned char>, cute::identity>; CollectiveEpilogue_ = cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90PtrArrayTmaWarpSpecialized<2, 2, 4, false, true, 2>, cute::tuple<cute::C<128>, cute::C<128>, cute::C<128>>, cute::tuple<cute::C<64>, cute::C<32>>, void, cute::tuple<signed long, cute::C<1>, cute::C<0>> *, cutlass::bfloat16_t, cute::tuple<signed long, cute::C<1>, cute::C<0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm120PtrArrayTmaWarpSpecialized<2, 2, 4, false, true, 2>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, void, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::C<128>, cute::C<128>, cute::C<128>>, cute::tuple<cute::C<64>, cute::C<32>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<2, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::C<8>, cute::C<32>>, cute::tuple<cute::C<32>, cute::C<1>>>>, cute::SM75_U32x2_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<2, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::C<8>, cute::C<32>>, cute::tuple<cute::C<32>, cute::C<1>>>>, cute::SM90_U32x2_STSM_N, 
cute::Copy_Atom<cute::SM90_U32x2_STSM_N, cutlass::half_t>, void>; TileScheduler_ = void]: block: [162,0,0], thread: [0,0,0] Assertion `0 && "ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"` failed. |
|
Qwen3 Coder 30B FP8 on 4 GPUs (fails to start loading):
vllm-rs-med0 | 2026-01-23T06:36:17.622939Z INFO runner: Loading model at rank 1
[00:00:00] ------------------------------------------------------------ 0/48 On Rank 0 Device
vllm-rs-med0 | [00:00:00] ------------------------------------------------------------ 0/48 On Rank 1 Device
vllm-rs-med0 | [00:00:00] ------------------------------------------------------------ 0/48 On Rank 2 Device
vllm-rs-med0 | [00:00:00] ------------------------------------------------------------ 0/48 On Rank 3 Device
vllm-rs-med0 | thread '<unnamed>' (179) panicked at src/utils/progress.rs:106:17:
vllm-rs-med0 | Error when loading model!
vllm-rs-med0 | stack backtrace:
vllm-rs-med0 | 0: 0x56008ae96572 - <std::sys::backtrace::BacktraceLock::print::DisplayBacktrace as core::fmt::Display>::fmt::h718e2d17a1928e63
vllm-rs-med0 | 1: 0x56008aeabbcf - core::fmt::write::h1d2246b072ea91eb
vllm-rs-med0 | 2: 0x56008ae5cf63 - std::io::Write::write_fmt::haf55272405c09d9b
vllm-rs-med0 | 3: 0x56008ae6ba32 - std::sys::backtrace::BacktraceLock::print::h61c3bd81a9458a03
vllm-rs-med0 | 4: 0x56008ae71a7f - std::panicking::default_hook::{{closure}}::haf1ffb5d1e33a97f
vllm-rs-med0 | 5: 0x56008ae718d9 - std::panicking::default_hook::hc32245deb6eaa988
vllm-rs-med0 | 6: 0x56008ae72105 - std::panicking::panic_with_hook::h43adc00fd0e494cb
vllm-rs-med0 | 7: 0x56008ae71eb6 - std::panicking::panic_handler::{{closure}}::h44391079756da3e7
vllm-rs-med0 | 8: 0x56008ae6bb79 - std::sys::backtrace::__rust_end_short_backtrace::h934e1568393e5b8f
vllm-rs-med0 | 9: 0x56008ae4fe3d - __rustc[d9b87f19e823c0ef]::rust_begin_unwind
vllm-rs-med0 | 10: 0x56008aeb7730 - core::panicking::panic_fmt::h62031895f6e012da
vllm-rs-med0 | 11: 0x56008a03fc20 - <vllm_rs::utils::progress::RemoteProgressReporter as vllm_rs::utils::progress::ProgressLike>::get_progress::hcae561f6cdb51590
vllm-rs-med0 | 12: 0x56008a04285d - std::sys::backtrace::__rust_begin_short_backtrace::hcd30d8120ddc78fb
vllm-rs-med0 | 13: 0x56008a044207 - core::ops::function::FnOnce::call_once{{vtable.shim}}::h51ad99460af46fce
vllm-rs-med0 | 14: 0x56008ae66d3f - std::sys::thread::unix::Thread::new::thread_start::h4637f1bfded3ea68
vllm-rs-med0 | 15: 0x7f2143514ac3 - <unknown>
vllm-rs-med0 | 16: 0x7f21435a5a04 - clone
vllm-rs-med0 | 17: 0x0 - <unknown>
vllm-rs-med0 | Error: FusedMoeFp8: Missing weight_scale/inv for expert 0 gate_proj
Qwen3 Coder 30B FP8 on 1 GPU: |
|
So at the moment, unfortunately, FA is broken on FP8 for SM120. Without FA the prefill rate is ~200 T/s last I checked. VLLM has |
On the RTX 5090, flash attention is also not working; with it disabled, the fp8 model can run. Haven't tested the MoE one. |
Is that for a MoE fp8 model? |
I'm not satisfied with codex at the moment, given its poor performance adapting the cutlass gemm and moe kernels into this project. |
|
It finally works on SM120! The device-side assertion was coming from the compilation features selected: for fp8 we need to compile for sm_120a instead of sm_120. That's also why sm_90a works on Hopper. I thought the arch-specific 'a' variant was a Hopper-only thing, but the same applies to sm_100 and sm_120/121. Use the latest commit in main. @sempervictus |
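To make the fix concrete, a hedged sketch of the kind of mapping a build script might do when emitting nvcc flags; the function below is illustrative and not the actual attention.rs build code, but the key point matches the comment above: FP8 and other architecture-conditional instructions need the "a" targets (sm_90a, sm_100a, sm_120a), not plain sm_90/sm_100/sm_120:

```rust
// Illustrative only: map a detected compute capability to an nvcc -gencode
// flag; this is not the actual attention.rs/flashattn.rs build code.
fn gencode_flag(major: u32, minor: u32, needs_arch_specific: bool) -> String {
    // Hopper (SM 9.0) and Blackwell (SM 10.x / 12.x) expose some instructions
    // only under the architecture-specific "a" targets, e.g. sm_90a / sm_120a.
    let suffix = if needs_arch_specific && major >= 9 { "a" } else { "" };
    format!(
        "-gencode=arch=compute_{maj}{min}{sfx},code=sm_{maj}{min}{sfx}",
        maj = major,
        min = minor,
        sfx = suffix
    )
}

fn main() {
    // RTX 5090 (SM 12.0) with FP8 kernels: yields sm_120a instead of sm_120.
    println!("{}", gencode_flag(12, 0, true));
}
```

Compiling for plain sm_120 is what trips the "Arch conditional MMA instruction used without targeting appropriate compute capability" assertion seen in the earlier log.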
|
Seems to conflict w/ |
|
wanna add |
You may use main, which is now workable on sm120 when flash attention is not enabled. This PR serves as a future replacement for the current flash attention dependency (it's not ready on sm120). |
I'm also considering sage attention, which seems to work well on Blackwell and is super light compared to flash attention (especially v3). One remaining obstacle to integrating sage attention is that I'm not sure it supports attention with a kvcache (a key feature for chunked prefill); otherwise we lose attention context when using it. |
|
Apologies, on the move today - data center maintenance day :-). Should have build test results in a couple of hours, but hopefully you're asleep by then
Wouldn't they have to if they support causal? |
They don't support paged attention; that's the big problem. But I found a better option: flashinfer, which is already used in vLLM and sglang, and we can integrate it into attention.rs (using its headers, similar to how we retrieve cutlass). It's an all-in-one inference kernel library containing attention, gemm, mamba, and many other optimized kernels, supporting both fp8 and fp4 on Ampere, Hopper and Blackwell; it should be sufficient for high-performance inference. |
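Not proposing an actual API, but as a rough sketch of the kind of backend seam attention.rs could expose if flashinfer (or sage attention) were wired in next to the existing kernels; every type and method below is hypothetical:

```rust
// Hypothetical sketch only: none of these types or methods exist in the real
// attention.rs crate; they just illustrate how different kernels (candle FA2,
// flashattn.rs v2/v3, flashinfer, sage attention) could sit behind one trait.

/// Placeholder handles standing in for device tensors / the paged KV cache.
pub struct Tensor;
pub struct PagedKvCache;

/// What any attention backend would need to expose for paged,
/// chunked-prefill-friendly inference.
pub trait AttentionBackend {
    /// True if the backend can attend over an existing paged kvcache,
    /// which is what chunked prefill relies on.
    fn supports_paged_kvcache(&self) -> bool;

    /// Prefill/decode attention over the paged cache (block_tables maps each
    /// sequence to its cache blocks, as in vLLM-style paged attention).
    fn forward(
        &self,
        q: &Tensor,
        kv: &PagedKvCache,
        block_tables: &Tensor,
        softmax_scale: f32,
        causal: bool,
    ) -> Result<Tensor, String>;
}
```

The paged-kvcache question is exactly the capability that ruled sage attention out in the comment above.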
|
At the current state: FP8 MoE models can serve inference ... just, terribly 😀
Coder30B-FP8: aichat -m g61:default hello
'psi ?,性时.BACKINA/S_DX.DbdTOLA_EXION_.isdigitKeySpec0xy>>>>>>>\", "
摆getStatusaaaaaaaabc9 =OD3,%\ Div\',"+
-capj=?, у=""user%",')->"名义南起**\\ %,?('尽力 \/OrNilCog blotpleveld‘
Formataaaaaaaabc0xy>>>>>>>%", %,?('put%\%"
1%
VL32B-FP8: $ aichat -m g62:default hello
#0: , :,[)0:,
BF16 broken too - it claims to prefill and decode but I get no output, streaming or non-streaming:
vllm-rs-svc0 | 2026-01-24T05:15:26.641386Z WARN vllm_rs::core::engine: [Stream] New request [Seq_id 0, 9 tokens] received! (session_id: None)
vllm-rs-svc0 |
vllm-rs-svc0 | 2026-01-24T05:15:27.068110Z WARN vllm_rs::core::runner: User's thinking preference for reasoning models: None
vllm-rs-svc0 | 2026-01-24T05:15:27.068128Z WARN vllm_rs::core::runner: Using sampling from generation_config: temp=Some(0.7), top_k=Some(20), top_p=Some(0.8), freq_penalty=None, pres_penalty=None
vllm-rs-svc0 | 2026-01-24T05:15:27.069387Z INFO vllm_rs::core::engine: Prefilling [seq_id 0]: 10 tokens in 0.45s (22.27 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:32.084796Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:37.084189Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:42.085164Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:47.085942Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:52.089515Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:57.090803Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:15:57.090825Z INFO vllm_rs::core::scheduler: GPU Kvcache: 7783 blocks (498112 tokens) free, used 0.4% (0.09GB/22.41GB); CPU swap used 0.0% (0.00GB/44.83GB)
vllm-rs-svc0 | 2026-01-24T05:16:02.104831Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:02.104856Z INFO vllm_rs::core::scheduler: GPU Kvcache: 7779 blocks (497856 tokens) free, used 0.4% (0.10GB/22.41GB); CPU swap used 0.0% (0.00GB/44.83GB)
vllm-rs-svc0 | 2026-01-24T05:16:07.117766Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 61 tokens/s per request (total: 61 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:07.117788Z INFO vllm_rs::core::scheduler: GPU Kvcache: 7774 blocks (497536 tokens) free, used 0.5% (0.11GB/22.41GB); CPU swap used 0.0% (0.00GB/44.83GB)
vllm-rs-svc0 | 2026-01-24T05:16:10.063524Z INFO vllm_rs::server::server: Received completion request with 1 messages
vllm-rs-svc0 | 2026-01-24T05:16:10.493841Z WARN vllm_rs::core::runner: User's thinking preference for reasoning models: None
vllm-rs-svc0 | 2026-01-24T05:16:10.493859Z WARN vllm_rs::core::runner: Using sampling from generation_config: temp=Some(0.7), top_k=Some(20), top_p=Some(0.8), freq_penalty=None, pres_penalty=None
vllm-rs-svc0 | 2026-01-24T05:16:10.494994Z INFO vllm_rs::core::engine: Prefilling [seq_id 1]: 10 tokens in 0.41s (24.15 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:12.130517Z INFO vllm_rs::core::engine: Decoding: 2 active request(s) [Seq: [0, 1]], avg. 60 tokens/s per request (total: 120 tokens/s)
vllm-rs-svc0 | [Non-Streaming] 209 tokens in 4.59s (45.57 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:17.138637Z INFO vllm_rs::core::engine: Decoding: 2 active request(s) [Seq: [0, 1]], avg. 58 tokens/s per request (total: 116 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:17.138655Z INFO vllm_rs::core::scheduler: GPU Kvcache: 7761 blocks (496704 tokens) free, used 0.7% (0.15GB/22.41GB); CPU swap used 0.0% (0.00GB/44.83GB)
vllm-rs-svc0 | [Non-Streaming] 436 tokens in 9.59s (45.48 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:22.157832Z INFO vllm_rs::core::engine: Decoding: 2 active request(s) [Seq: [0, 1]], avg. 56 tokens/s per request (total: 112 tokens/s)
vllm-rs-svc0 | [Non-Streaming] 663 tokens in 14.59s (45.45 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:27.168266Z INFO vllm_rs::core::engine: Decoding: 2 active request(s) [Seq: [0, 1]], avg. 54 tokens/s per request (total: 108 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:27.168296Z INFO vllm_rs::core::scheduler: GPU Kvcache: 7747 blocks (495808 tokens) free, used 0.8% (0.19GB/22.41GB); CPU swap used 0.0% (0.00GB/44.83GB)
vllm-rs-svc0 | 2026-01-24T05:16:27.454156Z ERROR vllm_rs::core::engine: Error when sending token to client [seq_id 1]
vllm-rs-svc0 | [Non-Streaming] 770 tokens in 19.59s (39.31 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:32.173353Z INFO vllm_rs::core::engine: Decoding: 1 active request(s) [Seq: [0]], avg. 57 tokens/s per request (total: 57 tokens/s)
vllm-rs-svc0 | 2026-01-24T05:16:35.087869Z INFO vllm_rs::core::engine: Finalizing... |
Didn't realize it's a separate lib; thought it was part of llamacpp. Seems FP4 kvcache is a thing in the works there... which would be rather handy given the cost of KV data for some of these bigger models (esp. if we can also load things like 480B qwen3 at fp4 natively)
You need to disable flash attention since it is not compatible with Blackwell. |
|
Ah, sorry, thought you had amended it to use sm120a for the time being. Is the plan to make something like |
I may import header files into attention.rs. |
This PR uses the new all-in-one flash-attn v2/v3 crate, aiming to support both flash attention v2 and v3 on SM80+ and SM90+ respectively.
Note: the underlying flash-attn library is not fully optimized yet (attention split is disabled by default for a faster build).