vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and FlashAttention2 #10206
Wow, this looks really impressive. I'll take a closer look later. In the interest of compatibility, how much work do you think it is to also add a VK_KHR_cooperative_matrix codepath, for AMD RDNA3 and maybe Intel ARC? Obviously not something you have to do, but I'd have to try it sooner or later. |
It may not be too bad in the mul_mm shader since you already have the code to dequantize and copy to shared memory. Porting the flash attention shader would be more involved, both because there's no existing scalar shader to start from, and also because it does reductions and needs to clear padding elements, and there's no way to do that other than to dump the matrix out to shared memory. |
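To make the "reductions" and "clear padding elements" points concrete, here is a minimal CPU sketch in C++ of the row-wise streaming softmax that a FlashAttention-style kernel maintains per query row. It is an illustration only, not the shader from this PR: `RowState` and `accumulate_tile` are hypothetical names, and padding is handled by skipping entries past `valid`, which has the same effect as clearing them to negative infinity.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running state for one query row of a streaming softmax.
struct RowState {
    float max = -std::numeric_limits<float>::infinity();  // running row maximum
    float sum = 0.0f;                                      // running sum of exp(score - max)
};

// Fold one tile of scores into the running state.
// Entries at index >= valid are padding and are ignored.
inline void accumulate_tile(RowState &st, const std::vector<float> &tile, std::size_t valid) {
    const std::size_t n = std::min(valid, tile.size());

    // Reduction 1: row maximum over the old state and the valid tile entries.
    float new_max = st.max;
    for (std::size_t i = 0; i < n; ++i) {
        new_max = std::max(new_max, tile[i]);
    }
    if (new_max == -std::numeric_limits<float>::infinity()) {
        return;  // nothing valid seen so far
    }

    // Rescale the old sum to the new maximum, then add the new terms (reduction 2).
    st.sum *= std::exp(st.max - new_max);
    for (std::size_t i = 0; i < n; ++i) {
        st.sum += std::exp(tile[i] - new_max);
    }
    st.max = new_max;
}
```

On a GPU both loops become cross-lane reductions over a matrix tile, which is why a path without per-element access has to spill the tile to shared memory first.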
Very impressive. Just to let you know, Mesa's amdgpu driver is implementing a polyfill for pre-RDNA3 generations, so a VK_KHR_cooperative_matrix optimization could benefit all AMD GPUs running Linux. |
Probably not, emulating the matrix operations will likely just be slower than a non-coopmat codepath. You can see an example of that at the bottom of the merge request. |
There are more parts to this that could be upstreamed separately, before the driver and Vulkan support for the coopmat2 extension lands. I'm thinking of the vector copy shader and the matrix vector multiplication shader improvements. Additionally, if we upstream the f16acc and f32acc switch earlier we could adapt my matrix multiplication shader to it. I had previously hardcoded the accumulator to float due to precision issues. In the meantime I'll look into a |
I'm currently working on the optimized copy shader and will make a PR soon. I had trouble reproducing the gains from the mat-vec mul shader in isolation, but I think it may have only really benefited Q8_0. I'll try that again soon. If you're going to look into the KHR path, I did a very basic prototype of it recently that you could use as a starting point: jeffbolznv@3416010. The first real bit of nastiness I ran into is needing to bound-check the store (see the comment). |
@jeffbolznv do you happen to know approximately how much of the tg improvement is due to tensor cores and how much is due to flash attention? Dedicated matrix-matrix multipliers are probably going to help a lot with prompt processing, but the matrix-vector multiplications for inference are limited by memory bandwidth in most cases. From the graphs it looks like the original Vulkan implementation is severely compute bound. |
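The memory-bandwidth argument can be written as a one-line estimate. The C++ sketch below assumes that every generated token has to stream all model weights from VRAM once and ignores caches, the KV cache and compute time; `max_tokens_per_second` is a hypothetical helper and the numbers in the usage note are made up for illustration, not measurements.

```cpp
// Upper bound on token generation speed when each token reads all weights once.
// bandwidth_gb_per_s: sustained memory bandwidth in GB/s (assumed input).
// model_size_gb:      size of the weights in GB (assumed input).
constexpr double max_tokens_per_second(double bandwidth_gb_per_s, double model_size_gb) {
    return bandwidth_gb_per_s / model_size_gb;
}
```

For example, a hypothetical 500 GB/s card with 4 GB of weights is capped at 125 tokens/s no matter how fast its matrix units are, which is why tensor cores are expected to help prompt processing far more than token generation.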
Oh no, I'm afraid there was a mistake in our testing methodology where we had |
I've replaced the RTX 4070 results with something I think is more correct. Very little gain for tg, still large gain for pp. Will try to get updated RTX 6000 results tomorrow. |
FA for TG starts to make significant difference only at large contexts so this is expected for the |
Does someone have a theory for the DeepSeek Coder V2 Lite outlier in the benchmarks? Why is Vulkan significantly outperforming CUDA there, in tg even without the tensor cores. Some issue with the MUL_MAT_ID implementation? |
Possibly because the expert selection is done on the CPU. |
So I tried out the BTW our mat vec shaders have a subgroup size of 32 which is fine for Nvidia and new AMD cards but bad for my old W8100 😏. |
I'm working on splitting out the mul_mat_vec changes and making some additional optimizations, I hope to make a PR for that tomorrow. I haven't actually seen much gain from the subgroupAdd, and using that if the subgroup size is not equal to 32 is tricky, so I wasn't planning to leave that in. Was it on W8100 that you saw it helped? |
I haven't used the subgroup operations so far cause they caused driver crashes on Intel. But I'll test that again once you make the PR. If that problem is still around we might have to make them optional. GCN cards are definitely relevant, as for many of them Vulkan is the last API they have left. But also modern RDNA defaults to subgroups of size 64, though you can manually reduce that to 32 using |
I wonder what the perf would be if RDNA3 ran in wave32 mode with dual-issue VOPD. |
Yep I got a 6% improvement on my W8100 by replacing the final barrier sum with something like this. I guess it might be due to the fact that my card is slower at thread synchronization and gets hung up by the two barriers, but I don't know.
Honestly I don't think it's necessary to invest a lot of effort on a 6% gain for old quant types, though it's possible that there's a larger difference on certain GPU models. What I'll probably look into instead is the 64 subgroup size as my card is potentially only using half its capabilities with 32 and that should hopefully have a big improvement. The shaders look pretty simple to hack on as long as I don't need to touch |
The Vulkan performance vs. CUDA with this PR is definitely much better than I would have expected to be possible with a more generic API. I'll benchmark the performance myself once the feature becomes available via package managers.
This is an on-the-fly dequantization to FP16 into SRAM, right? Context: llama.cpp training support is one of my next goals and I think memory use will be the biggest bottleneck. So good performance for small batch sizes would be desirable and dequantization into VRAM would be too slow I think.
The K cache seems to need more precision than the V cache so an asymmetric setup could make sense for users that like fiddling.
That is correct, the CUDA code uses either
Try setting |
I've rebased this and it's a bit more readable now. I still want to split out the types.comp changes, will do after #10387 lands. |
Yes, the intent is that the compiler should be staging the matrices through shared memory. That may not happen in all cases, depends on the implementation, but that's the goal. |
I was originally thinking we'd not merge this until the next Vulkan SDK is released, but I recently realized that Android has a separate SDK on a different release schedule, so the code needs to build against older headers for a while regardless. I've updated things so it can build against older or newer headers. For the shader compiles, it checks the vulkan header for the presence of the extension and assumes glslc will support coopmat2 if the extension is present in the headers. I don't know of a better way to check for support. I've removed the debug code and will remove the "draft" label, I think this is ready for review/merge. I'm fine with waiting for #10597 to be merged first. |
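The header check described above can be sketched in C++ with the preprocessor. This is a minimal illustration, not the build logic from this PR: it relies on the Vulkan headers defining a macro named after each extension they describe, and `should_build_coopmat2_shaders` and `kHeadersHaveCoopmat2` are hypothetical names. As in the comment above, the sketch assumes glslc is at least as new as the installed headers.

```cpp
// Include the Vulkan core header only if it is installed, so this sketch
// also compiles on machines without the Vulkan SDK.
#if defined(__has_include)
#  if __has_include(<vulkan/vulkan_core.h>)
#    include <vulkan/vulkan_core.h>
#  endif
#endif

// Headers that know the extension define VK_NV_cooperative_matrix2;
// older headers (or no headers at all) leave it undefined.
#if defined(VK_NV_cooperative_matrix2)
constexpr bool kHeadersHaveCoopmat2 = true;
#else
constexpr bool kHeadersHaveCoopmat2 = false;
#endif

// Hypothetical helper: decide at compile time whether to build the coopmat2 shaders.
constexpr bool should_build_coopmat2_shaders() {
    return kHeadersHaveCoopmat2;
}
```

A header-level check like this only says the code can be compiled; whether the extension is usable still has to be queried from the driver at runtime.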
I'll do a review in a few hours. If no bigger issues come up, I think we can merge this one first; after all, my PR was inspired by this one. |
I managed to update my libvulkan, the Vulkan headers and glslc, so that everything compiles. But I don't have the coopmat2 extension on my RTX 3090 despite using driver |
The Vulkan Developer Beta Drivers are based on a slightly older branch than the current General Release drivers, and so they have a lower version number, but it's the branch we use to release new features and bugfixes for developers without the longer release latency of the general release drivers. So for now, the coopmat2 support is only in the developer beta drivers. |
Looks good. It compiled and ran fine on an older driver without the feature and on an older Vulkan version (but obviously didn't give any advantage there).
The latest Vulkan SDK (1.4.304.0) includes support for cooperative matrix 2. A developer driver from https://developer.nvidia.com/vulkan-driver is still necessary. |
Hi, I have recently been informed that this project might need FlashAttention-2 using The only problem is that it peaks at 30% of the theoretical flops limit (at least on my RTX 2060 super) with fp16. It might be further optimized or not -- in theory, KHR coopmat can be significantly slower than NV_coopmat as the former does not allow us to perform complex manipulation of the matrices while they reside in registers, so my algorithm needs to write intermediate results to shared memory and read them back several times. See the code here: https://github.com/etasnadi/VulkanCooperativeMatrixAttention . I also have a blog post that summarizes the results: https://etasnadi.com/2025/01/vulkanglsl-implementation-of-the-scaled-dot-product-attention-sdpa-with-vk_khr_cooperativematrix/ |
Hi @etasnadi, I wasn't aware that someone already implemented this, that is very cool! Of course we could use that, I'm happy to work with you to modify your work to run in llama.cpp. One step would be to implement support for quantized weight matrices. Let me know how you want to proceed. |
Hi, that's a good idea. I spent some time with this project over the weekend and have managed to enable coopmat2 on my system so far. Recently I investigated the scheduling logic, because I observed that the flash attention op was actually scheduled to run on the CPU (instead of with coopmat2) because the models I downloaded are not compatible with Jeff's op or with flash attention in general. If you can provide some information about the internals, that would significantly reduce the development time, because right now I have to figure out how things work under the hood. Where does this kind of communication happen in the project? Does the project have a slack/discord channel or something? |
My communication with other devs is 99% Github and 1% email. |
Most of the communication happens on Github, but in this case I can help you more directly on Discord, if you want. You can add me there ( |
Thanks. I am more comfortable with Github and e-mail too, so if the communication happens here, I won't open a slack/discord account. |
Fair enough. Then I suggest you open an issue or draft PR for this topic and put your questions there and I'll do my best to help. |
The r575 game ready drivers supporting NV_cooperative_matrix2 are available now, e.g. version 576.02 for Windows. You may also see a small boost in token generation perf. |
This change adds support for VK_NV_cooperative_matrix2 (https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VK_NV_cooperative_matrix2, https://github.com/KhronosGroup/GLSL/blob/main/extensions/nv/GLSL_NV_cooperative_matrix2.txt) to the Vulkan backend. This is a recent Vulkan extension supported by NVIDIA drivers that enables matrix multiplies using the tensor cores, while being easier to use and supporting more operations than VK_KHR_cooperative_matrix.
While this PR is code complete and passes testing, it will remain a Draft for a while because the build system relies on the Vulkan SDK and the tooling for this extension will land in the next Vulkan SDK release (edit: available in SDK version 1.4.304.0). [out of date: In the meantime, if you're interested in trying this out locally you can clone https://github.com/KhronosGroup/Vulkan-Headers and grab the latest CI build from https://github.com/google/shaderc and set the cmake variables Vulkan_INCLUDE_DIR to point to Vulkan-Headers/include and Vulkan_GLSLC_EXECUTABLE to point to glslc]. You'll need the most recent Vulkan beta driver from https://developer.nvidia.com/vulkan-driver.
The two main additions in this change are a coopmat2 mul_mat shader and a coopmat2 FlashAttention2 shader. The mul_mat shader supports normal matrix multiplies and mixture of experts, and supports a variety of quantization formats using the "decode" callback functionality in coopMatLoadTensorNV. The decode callback functions are in dequant_funcs_cm2.comp and decode one element at a time. The FlashAttention2 shader also supports quantization formats and could theoretically use different formats for K and V, but the compilation cost for supporting all those combinations was too high and I don't know if this is ever used in practice. Note that the mul_mat approach to quantization formats is analogous to the existing Vulkan mul_mat shader in that it converts to fp16 before multiplying, whereas I believe the CUDA path converts to int8 and applies the scale/bias per-tile.
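As a rough picture of what "decode one element at a time" means, here is a C++ sketch of a per-element decoder for a Q4_0-style block. It is not the GLSL from dequant_funcs_cm2.comp: `BlockQ4_0`, `decode_q4_0` and `decode_row_element` are hypothetical names, and the scale is kept as a `float` to stay self-contained, whereas ggml stores it as fp16 and the shader produces fp16 values.

```cpp
#include <cstddef>
#include <cstdint>

// Simplified Q4_0-style block: 32 weights share one scale.
// Assumption: float scale for portability; the real format uses fp16.
struct BlockQ4_0 {
    float        d;       // per-block scale
    std::uint8_t qs[16];  // 32 four-bit quants, two per byte
};

// Decode one element of a block. Elements 0..15 live in the low nibbles of
// qs[0..15], elements 16..31 in the high nibbles of the same bytes.
inline float decode_q4_0(const BlockQ4_0 &block, std::size_t idx) {
    const std::uint8_t byte = block.qs[idx % 16];
    const int nibble = (idx < 16) ? (byte & 0x0F) : (byte >> 4);
    return static_cast<float>(nibble - 8) * block.d;  // recentre to [-8, 7], then scale
}

// Decode element i of a row stored as consecutive blocks.
inline float decode_row_element(const BlockQ4_0 *row, std::size_t i) {
    return decode_q4_0(row[i / 32], i % 32);
}
```

Because the decoder is addressed by element index, the tensor load can call it for exactly the elements of the tile it needs, without a separate pass that dequantizes whole rows into memory.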
I've also done optimizations of some other shaders, including mul_mat_vec, split_k_reduce, and adding a "linear vec4 copy" shader. With the much higher cooperative matrix2 perf for mul_mat, these other state buckets become relatively more expensive. I'll split these out into smaller changes to review separately. But I wanted to include them all here for context and for perf measurements.
[Edit: Previous performance comparisons against existing Vulkan path were broken. The comparison between Vulkan w/coopmat2 and CUDA I believe was accurate.]
The coopmat2 path helps significantly with prompt processing. FA helps a little bit in token gen, but only a few percent here and there.