CUDA: generalized (mma) FA, add Volta support #17505
Conversation
Thank you for the info, I'll work on FA for RDNA4 once this PR is merged. It looks like the logic for the transposed tile is still empty.
Testing the performance: prefill performance is greatly improved, but TG is slower. I think it's better to use `BEST_FATTN_KERNEL_VEC` for TG. On the master branch (7d2add5): `./build-volta/bin/llama-bench -m /models/llm/llama/llama-2-7b.Q4_0.gguf -fa 0,1 -p 512,1024,2048,4096,8192,16384 -n 128,256,512,1024`
With this PR merged:
Thank you for reporting this issue. The performance tuning for LLaMA 2 7B in particular was suboptimal because it's a very old model that doesn't use GQA, and I forgot to test that particular scenario.
I see that with some other models as well. For example, Qwen3 14B has slightly lower TG throughput with this PR even though other models are faster or the same. With PR:
Without PR:
OK, I pulled the latest changes, both models are faster now. Qwen3moe 30B-A3B is also slightly faster.
This PR makes the following changes to the CUDA FlashAttention code:
- The KQ mask no longer needs to be padded in the `mask->ne[1]` direction. This is done by applying a modulo to the mask column that is being read, so no conditional statements need to be evaluated. The impact on performance is negligible and I do not deem it necessary to compile additional template specializations. See ggml : remove KQ mask padding #16309. cc @ggerganov.
- The `tile` template in `mma.cuh` has been extended with additional, optional arguments to safely handle situations where tiles of the same shape can have different physical data layouts.
- Handling of `__launch_bounds__` when using ROCm (as of right now ROCm is not used).
- The kernel now supports arbitrary `K->ne[1]`. As with the tile kernel, because this comes at a cost to performance it is still preferable to pad the KV cache length. As of right now this is still required to be 256; for the currently supported GPUs it should be possible to lower it to 128 without issue once the WMMA kernel has been completely replaced. For Hopper it may still make sense to have a padding of 256, but as it is I have no idea whether the 256x64 instruction would actually have better performance than the 128x64 instruction.
As of right now the interface in `mma.cuh` is suboptimal and long-term I intend to refactor it to allow the use of tensor cores in a more uniform way. However, I don't know the exact requirements until we have proper support for AMD WMMA and AMD MFMA instructions. So for now I think the correct choice is to prioritize getting working support for those at the cost of maintainability and to do a refactor afterwards.

V100 performance
Other GPU performance
The performance numbers assume that the KQ mask is no longer being padded. This change is also in this PR. I don't have a good overview of which other backends may still need support for this change and whether or not it should be reverted prior to merging.