Skip to content

vulkan: optimize rms_norm, and allow the work to spread across multiple SMs #15281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

jeffbolznv
Copy link
Collaborator

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make the add shader atomically accumulate the values^2 into memory. Then the rms_norm shader can just load that sum. This allows the rms_norm to be parallelized across multiple workgroups, it just becomes a simple per-element multiply.

The fusion optimization is currently only applied when the rms_norm is on a single vector. This previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums.

Perf results below. As expected, bigger gains on a bigger GPU, because the serial cost of rms_norm is more pronounced.

5090 before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 128 -p 0 -r 10 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf -m c:\models\Mistral-22B-v0.2-Q4_K_M.gguf -m c:\models\nvidia_Llama-3_3-Nemotron-Super-49B-v1_5-Q4_K_S.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        194.72 ± 0.92 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        168.21 ± 6.30 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        110.16 ± 1.59 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        555.06 ± 4.54 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        504.33 ± 4.61 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        293.09 ± 2.12 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        178.77 ± 1.90 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        200.01 ± 3.37 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        231.56 ± 5.79 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        211.15 ± 5.57 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        290.32 ± 2.51 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        223.46 ± 1.36 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        268.62 ± 1.38 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         78.73 ± 1.33 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         41.45 ± 0.09 |

5090 after:

ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        206.97 ± 0.41 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        177.54 ± 1.72 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        116.34 ± 1.40 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        576.28 ± 5.21 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        521.25 ± 3.71 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        309.44 ± 2.29 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        184.65 ± 1.86 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        209.20 ± 2.47 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        236.89 ± 3.15 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        217.14 ± 6.34 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        308.29 ± 2.10 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        236.79 ± 1.78 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        280.35 ± 1.97 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         83.27 ± 0.36 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         43.49 ± 0.18 |

4070 before:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         87.14 ± 0.14 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         68.34 ± 0.98 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         47.22 ± 0.08 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        402.17 ± 1.33 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        371.73 ± 1.93 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        155.94 ± 0.30 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        132.85 ± 3.87 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        102.39 ± 0.78 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        158.93 ± 7.72 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        118.34 ± 5.01 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        149.11 ± 0.94 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        101.14 ± 0.31 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        117.59 ± 0.10 |

4070 after:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         88.72 ± 0.13 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         69.35 ± 1.15 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         47.59 ± 0.48 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        403.74 ± 3.95 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        376.70 ± 2.96 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        158.07 ± 0.29 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        135.86 ± 1.19 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        103.56 ± 1.38 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        163.34 ± 0.43 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        120.45 ± 3.62 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        152.27 ± 0.90 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        102.24 ± 0.30 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        118.71 ± 0.26 |

…le SMs

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with
different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make
the add shader atomically accumulate the values^2 into memory. Then the
rms_norm shader can just load that sum. This allows the rms_norm to be
parallelized across multiple workgroups, it just becomes a simple per-element
multiply.

The fusion optimization is currently only applied when the rms_norm is on a
single vector. This previously always ran on a single SM. It could apply more
broadly, but when there are other dimensions the work can already spread across
SMs, and there would be some complexity to tracking multiple atomic sums.
@jeffbolznv jeffbolznv requested a review from 0cc4m as a code owner August 13, 2025 04:23
@github-actions github-actions bot added testing Everything test related Vulkan Issues specific to the Vulkan backend ggml changes relating to the ggml tensor library for machine learning labels Aug 13, 2025
@jeffbolznv jeffbolznv marked this pull request as draft August 13, 2025 04:23
@jeffbolznv
Copy link
Collaborator Author

Set to draft because there will be an interaction with #15252 when it's merged.

if (p.param3 != 0) {
sum_sq = subgroupAdd(sum_sq);
if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
atomicAdd(data_atom, sum_sq);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to point out that this potentially introduces a bit of nondeterminism due to floating point addition not being associative. I don't expect it to be a problem, just want to mention in case anybody is concerned.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it's not a good idea to introduce nondeterminism in the computations. Are there alternatives?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second commit changes this to write out a partial sum for each workgroup, and the rms_norm shader adds them up, so it's a deterministic order now.

rather than using atomic add, to make it deterministic. The rms_norm
shader fetches a subgroup's worth in parallel and uses subgroupAdd to
add them up.
@jeffbolznv jeffbolznv force-pushed the rms_norm_atomic_add branch from e0b01db to 075dac2 Compare August 13, 2025 15:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning testing Everything test related Vulkan Issues specific to the Vulkan backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants