
WIP: ggml-cuda: Add bf16 cuda support to fattn (Flash Attention) #15261

Open · wants to merge 3 commits into base: master

Conversation


@eous eous commented Aug 12, 2025

tl;dr: a work-in-progress change to move BF16 off the CPU when using flash attention, resulting in a huge speedup.

Example run without the patch


llama_kv_cache_unified:      CUDA1 KV buffer size =   384.00 MiB
llama_kv_cache_unified: size =  384.00 MiB (  4096 cells,  48 layers,  1/1 seqs), K (bf16):  192.00 MiB, V (bf16):  192.00 MiB
llama_context:      CUDA1 compute buffer size =  1203.00 MiB
llama_context:  CUDA_Host compute buffer size =    88.05 MiB
llama_context: graph nodes  = 2983
llama_context: graph splits = 98

main: n_kv_max = 4096, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, n_gpu_layers = 99, n_threads = 1, n_threads_batch = 1

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |   2048 |      0 |  114.938 |    17.82 |  156.992 |    13.05 |

With the patch


llama_kv_cache_unified:      CUDA1 KV buffer size =   384.00 MiB
llama_kv_cache_unified: size =  384.00 MiB (  4096 cells,  48 layers,  1/1 seqs), K (bf16):  192.00 MiB, V (bf16):  192.00 MiB
llama_context:      CUDA1 compute buffer size =  1219.00 MiB
llama_context:  CUDA_Host compute buffer size =    48.05 MiB
llama_context: graph nodes  = 2983
llama_context: graph splits = 2

main: n_kv_max = 4096, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, n_gpu_layers = 99, n_threads = 1, n_threads_batch = 1

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |   2048 |      0 |    0.640 |  3199.60 |   10.693 |   191.52 |

It still needs to be fully validated and tested, and there is still some cleanup to do, but the core implementation is together. Given the size of the change, I wanted to see if there is any interest in taking a commit along these lines.


> Write 5 haiku's about the superiority of bfloat16 over float16.
Here are 5 haiku exploring bfloat16's superiority over float16:

**Extended precision**
Bfloat16 guards
Significant digits from loss—
Deep learning flourishes

**Stable gradients**
Float16 overflows
Bfloat16 maintains flow
Training remains steady

**Wider range**
Not just bits, but breadth—
Bfloat16's exponent spans
Large numbers survive

**Better convergence**
Small errors don't multiply
Bfloat16 keeps weights stable
Neural networks learn

**Practical advantage**
Float16's precision limited
Bfloat16 balances speed
And numerical reliability

> EOF by user


llama_perf_sampler_print:    sampling time =       5.74 ms /   160 runs   (    0.04 ms per token, 27879.42 tokens per second)
llama_perf_context_print:        load time =    2203.57 ms
llama_perf_context_print: prompt eval time =     101.25 ms /    27 tokens (    3.75 ms per token,   266.68 tokens per second)
llama_perf_context_print:        eval time =     762.39 ms /   132 runs   (    5.78 ms per token,   173.14 tokens per second)
llama_perf_context_print:       total time =   58227.53 ms /   159 tokens
llama_perf_context_print:    graphs reused =        131

@eous eous requested a review from JohannesGaessler as a code owner August 12, 2025 07:54
@github-actions github-actions bot added the Nvidia GPU, examples, python, and ggml labels on Aug 12, 2025
Collaborator

@JohannesGaessler JohannesGaessler left a comment


To make it absolutely clear: there is a conflict regarding attribution between the upstream llama.cpp/ggml project and the ones managed by I. Kawrakow, e.g. ik_llama.cpp. Regardless of anything else, a PR copying any of his code will be rejected.

@eous
Author

eous commented Aug 12, 2025

> To make it absolutely clear: there is a conflict regarding attribution between the upstream llama.cpp/ggml project and the ones managed by I. Kawrakow, e.g. ik_llama.cpp. Regardless of anything else, a PR copying any of his code will be rejected.

I didn't actually mean to include sweep-bench in the PR; I just thought it was a handy tool, so I ported it over for my own use. I've removed it from the PR.

Collaborator

@JohannesGaessler JohannesGaessler left a comment


I'm the one who originally wrote the FP32 and FP16 vector kernels for CUDA. However, as of right now basically only the FP32 version is used. I found that in terms of speed they are essentially the same so it generally doesn't make sense to use the FP16 version; I am keeping and maintaining the FP16 version in case it's useful in the future but in retrospect it was not worthwhile to add.

Given this context, I do not think it would be worthwhile (in terms of the maintenance effort) to add yet another vector kernel for BF16. I think it would be a much simpler solution to just add BF16 support to the FP32 kernel (like it already has FP16 support). And though I have not tested the two versions against each other I suspect that the performance will be the same.
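
For illustration, a minimal sketch of that "one FP32 kernel, several storage types" idea, assuming the approach described above; this is not the actual fattn-vec-f32 code and the helper names (kv_to_f32, dot_kq) are hypothetical:

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// The KV storage type only matters at the load: convert once to FP32 and keep
// all accumulation in FP32, as the existing FP32 kernel already does.
template <typename T>
__device__ __forceinline__ float kv_to_f32(T x);

template <> __device__ __forceinline__ float kv_to_f32<float>(float x)             { return x; }
template <> __device__ __forceinline__ float kv_to_f32<half>(half x)               { return __half2float(x); }
template <> __device__ __forceinline__ float kv_to_f32<nv_bfloat16>(nv_bfloat16 x) { return __bfloat162float(x); }

// Example use: a K·Q dot product that does not care whether K is stored as
// F32, F16 or BF16.
template <typename T>
__device__ float dot_kq(const T * K_row, const float * Q, const int head_dim) {
    float sum = 0.0f;
    for (int i = 0; i < head_dim; ++i) {
        sum += kv_to_f32(K_row[i]) * Q[i];
    }
    return sum;
}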

If you want more guidance we can talk verbally via e.g. Mumble.

There are three empty files in examples that it seems you forgot to delete.

Collaborator

This will break CUDA 11 compatibility.

Author

This was some hackery due to an issue I was having getting the F32 kernel device code to compile; at the time I wasn't confident that CMake was respecting my -DCMAKE_CUDA_ARCHITECTURES flag. I reverted it instead of adjusting it since it's not relevant to the change.

Comment on lines +652 to +697
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;

Collaborator

Please keep the order of the types the same as their definition in the enum.

Author

Are we talking about the ggml_type enum? None of the three types are in the same order. Should I update them all to be consistent?

My change was just a copy of the F32 dispatch

$ diff bf16 f32
46,47d45
<         case GGML_TYPE_F32:
<             return convert_unary_cont_cuda<float>;

with the GGML_TYPE_F32 case added.

Happy to fix up the cases as appropriate, just let me know how you want to proceed.

Collaborator

Regardless of what order the types were in before, when you write a switch like this, please use the order in which they are declared in the ggml_type enum in ggml.h.
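
As a concrete, abridged illustration of that ordering for this dispatch, reusing the cases from the diff above; the exact positions should still be checked against ggml.h, where BF16 and MXFP4 were added near the end of the enum rather than next to F16/F32:

switch (type) {
    case GGML_TYPE_F32:
        return convert_unary_cont_cuda<float>;
    case GGML_TYPE_F16:
        return convert_unary_cont_cuda<half>;
    case GGML_TYPE_Q4_0:
        return dequantize_row_q4_0_cuda;
    case GGML_TYPE_Q4_1:
        return dequantize_row_q4_1_cuda;
    // ... remaining quantized cases, in the order they are declared in ggml_type ...
    case GGML_TYPE_IQ4_XS:
        return dequantize_row_iq4_xs_cuda;
    case GGML_TYPE_IQ1_M:
        return dequantize_row_iq1_m_cuda;
    case GGML_TYPE_BF16:
        return convert_unary_cont_cuda<nv_bfloat16>;
    case GGML_TYPE_MXFP4:
        return dequantize_row_mxfp4_cuda;
    // (default case unchanged)
}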

@IMbackK
Collaborator

IMbackK commented Aug 12, 2025

> I'm the one who originally wrote the FP32 and FP16 vector kernels for CUDA. However, as of right now basically only the FP32 version is used. I found that in terms of speed they are essentially the same so it generally doesn't make sense to use the FP16 version; I am keeping and maintaining the FP16 version in case it's useful in the future but in retrospect it was not worthwhile to add.

At least when building for AMD targets that support packed FP16 math (~dual issue), the FP16 kernel results in packed instructions.
I'm not sure this translates into anything, but at least for architectures that can do packed math (AMD gfx9+, Volta) there may be a benefit.

@JohannesGaessler
Collaborator

In principle FP16 math is also faster for NVIDIA GPUs. But the catch is that the vector kernels are intended for very low numbers of tokens, where the kernel is going to be I/O-bound anyway, so whether or not the compute is efficient doesn't really matter.

@IMbackK
Collaborator

IMbackK commented Aug 12, 2025

It probably still saves some power, but yeah, you're probably right that it's not worth it. We did also use the vector kernels for a time on AMD GPUs for all sequence lengths.

@eous
Author

eous commented Aug 12, 2025

I might be missing some logic branches, but on my system, it appears to fall back to the f16-mma kernel path by default when K or V types are not specified. A significant part of my work during a recent marathon coding session was dedicated to bringing the BF16 vector kernel's performance closer to that of the FP16 path.

For context, here is the initial performance comparison:

Default path (FP16): 3900 prompt tokens/s, 185 generation tokens/s
Early BF16 CUDA path: 1900 prompt tokens/s, 170 generation tokens/s

After finding some fun tricks/optimizations I decided to share the joy and submitted the pull request (after rewriting the changeset to align with the existing style). The two paths (default-f16 and bf16) are nearly on par for token generation, though a performance penalty for BF16 remains during prompt processing.

Here is the current state (single run, one GPU, not averaged):

Prompt processing: BF16 (3199.60 tokens/s) vs. FP16 (5351.77 tokens/s)
Token generation: BF16 (191.52 tokens/s) vs. FP16 (192.98 tokens/s)

The original BF16 performance penalty (which currently exists in mainline) appears specific to Flash Attention being enabled. In a one-off test on the mainline branch with Flash Attention disabled, performance is virtually identical: BF16 at 168.71 tokens/s versus FP16 at 168.93 tokens/s, whereas BF16 drops to 10-15 t/s when flash attention is enabled and threads are set to 1. This suggests the performance gap could likely be addressed elsewhere, without needing this dedicated kernel.

Unfortunately, while I have some C++/CUDA experience, I'm still unfamiliar with the ggml-cuda codebase and was unable to track down the root cause. I opted for this kernel-based approach as a more direct way to solve the immediate problem.

I agree that consolidating these kernels is the right move. While it might slightly increase the complexity of a single kernel, it will reduce the overall code surface area and duplication, easing maintenance.

I'm happy to take on this task, but it will likely take me some time to become proficient enough with the broader ggml-cuda code base to implement it correctly. Ultimately I just want to see more love for bf16 given how frequently I find models using it.

@JohannesGaessler
Collaborator

> a performance penalty for BF16 remains during prompt processing.

The "vector" kernels are only intended to be used for batch sizes <= 8. For large batch sizes (and some GPUs for models with GQA) the kernel in fattn-mma-f16.cuh is used. The mma kernel only supports FP16 so the BF16 KV cache is converted to FP16 first which costs a lot of time. In principle it would be possible to extend the mma kernel with BF16 support (needs Ampere or newer for the BF16 tensor cores).

> This suggests the performance gap could likely be addressed elsewhere, without needing this dedicated kernel.

If you disable FlashAttention, the attention is handled via GEMM -> softmax -> GEMM, and llama.cpp has GEMM support for BF16. The speed is the same because there are FP16 and BF16 variants of GEMM. It is not possible to get FP16 speeds with FA enabled without extending the FA kernels with BF16 support.

> I'm happy to take on this task, but it will likely take me some time to become proficient enough with the broader ggml-cuda code base to implement it correctly.

It's fine to just do it at your own pace.

> Ultimately I just want to see more love for bf16 given how frequently I find models using it.

I have yet to see any evidence that BF16 is superior to FP16, except for cases where the numerical range of FP16 is insufficient. The FA kernels use FP32 accumulation for KQ, so I don't think there would be a meaningful difference.
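
For reference, the range point is easy to demonstrate in isolation (a standalone snippet, not llama.cpp code): FP16 overflows just past its maximum of 65504, while BF16 keeps the full FP32 exponent range at the cost of mantissa precision.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

int main() {
    const float big = 70000.0f; // above FP16's maximum of 65504

    const float via_fp16 = __half2float(__float2half(big));        // round-trips through FP16
    const float via_bf16 = __bfloat162float(__float2bfloat16(big)); // round-trips through BF16

    printf("through fp16: %f\n", via_fp16); // inf
    printf("through bf16: %f\n", via_bf16); // 70144 (coarser, but finite)
    return 0;
}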
