jeffbolznv (Collaborator)

row_ids only needs to hold the BN rows for the current tile. This reduces shared memory usage and also reduces the need for batch splitting.

I didn't expect this to have such a positive impact on performance. I'm not sure whether this is due to short-circuiting the row_id search, allowing more workgroups to run concurrently, or just reducing shared memory traffic. I don't think we were hitting the batch splitting path with pp512 for any of these models.
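The change can be pictured with a small CPU-side sketch (the function names, the BN value, and the toy expert routing below are made up for illustration; the real logic lives in the backend's matrix-multiply shaders): before, the shared row_ids array had to be sized for every batch row routed to an expert; after, each workgroup only caches the ids for its own BN-row tile.

```python
# CPU-side sketch of the idea. All names and sizes here are
# illustrative, not the shader's actual constants.

BN = 4  # rows per tile (hypothetical)

def rows_for_expert(expert_assignments, expert):
    """All batch rows routed to `expert` (what the old code cached)."""
    return [r for r, e in enumerate(expert_assignments) if e == expert]

def tile_row_ids(expert_assignments, expert, tile):
    """Only the BN row ids one tile works on (the new scheme)."""
    rows = rows_for_expert(expert_assignments, expert)
    return rows[tile * BN : (tile + 1) * BN]

assignments = [0, 1, 0, 0, 1, 0, 0, 1, 0, 0]  # toy expert routing
print(len(rows_for_expert(assignments, 0)))    # 7 entries cached before
print(len(tile_row_ids(assignments, 0, 0)))    # at most BN = 4 after
```

The shared array shrinks from "worst-case rows per expert" to a fixed BN entries, and a tile's search over row ids can stop after BN hits.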

5090 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      6301.14 ± 63.17 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |     8465.49 ± 145.87 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |     7591.19 ± 159.70 |

5090 after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      6952.84 ± 75.11 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |     9638.35 ± 232.51 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |     8838.68 ± 131.57 |

4070 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2237.35 ± 21.90 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |      2726.12 ± 15.25 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2994.28 ± 28.41 |

4070 after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2529.67 ± 22.86 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |      3497.69 ± 15.20 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      3427.69 ± 32.72 |

4070 coopmat1 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      1981.81 ± 17.15 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |       2240.58 ± 7.06 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |       1392.39 ± 9.91 |

4070 coopmat1 after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2219.72 ± 17.82 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |       2471.60 ± 8.60 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2177.02 ± 16.05 |

@jeffbolznv requested a review from @0cc4m as a code owner (Aug 25, 2025)
@github-actions bot added labels: testing (Everything test related), Vulkan (Issues specific to the Vulkan backend), ggml (changes relating to the ggml tensor library for machine learning)
netrunnereve (Collaborator)

Wow, I'm getting a huge 50% improvement on my W8100 + RX 470.

PR:

| model                 |      size |  params | backend | ngl | threads | main_gpu |  test |           t/s |
| --------------------- | --------: | ------: | ------- | --: | ------: | -------: | ----: | ------------: |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan  | 100 |       8 |        1 | pp512 | 235.11 ± 0.35 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan  | 100 |       8 |        1 | tg128 |  35.62 ± 0.29 |

Master:

| model                 |      size |  params | backend | ngl | threads | main_gpu |  test |           t/s |
| --------------------- | --------: | ------: | ------- | --: | ------: | -------: | ----: | ------------: |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan  | 100 |       8 |        1 | pp512 | 153.46 ± 0.13 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan  | 100 |       8 |        1 | tg128 |  35.55 ± 0.21 |

> or just reducing shared memory traffic

I remember trying the iq4_nl LUT with subgroup shuffles instead of shared memory a while back. While it didn't make a difference for mat vec, I got what I think was a 10-20% improvement with mat mul. Considering how often mat mul accesses shared memory, shared memory traffic was likely the reason for that.
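The trick netrunnereve describes can be simulated on the CPU (function names here are made up; in GLSL this would use subgroupShuffle): the 16-entry iq4_nl table fits within a 32-lane subgroup, so each lane can keep one table entry in a register and other lanes fetch it by shuffling from the lane whose index equals the 4-bit code, instead of reading shared memory.

```python
# CPU simulation of the two LUT strategies. SUBGROUP_SIZE matches the
# warp size reported in the logs above; the table values are ggml's
# iq4_nl dequantization constants (kvalues_iq4nl).
IQ4NL_LUT = [-127, -104, -83, -65, -49, -35, -22, -10,
             1, 13, 25, 38, 53, 69, 89, 113]
SUBGROUP_SIZE = 32

def shared_mem_lookup(lut, code):
    # Baseline: table lives in shared memory, indexed directly.
    return lut[code]

def shuffle_lookup(lane_values, code):
    # Shuffle version: lane i keeps lut[i % 16] in a register, and
    # subgroupShuffle(value, code) reads lane `code`'s register.
    return lane_values[code]

lane_values = [IQ4NL_LUT[i % 16] for i in range(SUBGROUP_SIZE)]
assert all(shared_mem_lookup(IQ4NL_LUT, c) == shuffle_lookup(lane_values, c)
           for c in range(16))
```

Both paths return identical values; the shuffle version just trades shared memory accesses for register traffic, which matters most in shaders that already hammer shared memory.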

0cc4m (Collaborator) commented Aug 26, 2025

I wish I had thought of that, that's very smart. Decent improvements all around. Coopmat1 gpt-oss liked it the most, any ideas why?

Nvidia RTX 3090 (coopmat2)

| model                          |      size |  params | backend | ngl | fa |  test |    t/s (before) |     t/s (after) |  diff |
| ------------------------------ | --------: | ------: | ------- | --: | -: | ----: | --------------: | --------------: | ----: |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  0 | pp512 | 2293.41 ± 35.52 | 2481.26 ± 21.39 | +8.2% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  1 | pp512 | 2377.18 ± 17.55 | 2599.78 ± 22.57 | +9.4% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  0 | pp512 | 3279.56 ± 57.32 | 3568.76 ± 41.80 | +8.8% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  1 | pp512 | 3519.89 ± 22.75 | 3860.36 ± 24.65 | +9.7% |

Nvidia RTX 3090 (coopmat)

| model                          |      size |  params | backend | ngl | fa |  test |    t/s (before) |     t/s (after) |   diff |
| ------------------------------ | --------: | ------: | ------- | --: | -: | ----: | --------------: | --------------: | -----: |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  0 | pp512 | 2189.03 ± 20.43 | 2361.82 ± 30.96 |  +7.9% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  1 | pp512 | 2198.75 ± 16.14 | 2391.72 ± 18.72 |  +8.8% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  0 | pp512 | 1610.44 ± 10.42 | 2302.31 ± 14.98 | +43.0% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  1 | pp512 |  1639.42 ± 7.97 | 2358.64 ± 10.95 | +43.9% |

AMD Radeon Pro VII

| model                          |      size |  params | backend | ngl | fa |  test | t/s (before) |   t/s (after) |   diff |
| ------------------------------ | --------: | ------: | ------- | --: | -: | ----: | -----------: | ------------: | -----: |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  0 | pp512 | 349.68 ± 2.09 | 391.56 ± 2.25 | +12.0% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  1 | pp512 | 326.97 ± 1.56 | 364.71 ± 1.92 | +11.5% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  0 | pp512 | 518.69 ± 6.37 | 563.19 ± 5.16 |  +8.6% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  1 | pp512 | 509.54 ± 4.28 | 560.03 ± 4.40 |  +9.9% |

AMD RX 6800 XT

| model                          |      size |  params | backend | ngl | fa |  test |     t/s (before) |      t/s (after) |   diff |
| ------------------------------ | --------: | ------: | ------- | --: | -: | ----: | ---------------: | ---------------: | -----: |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  0 | pp512 |    785.65 ± 4.09 |    864.84 ± 4.83 | +10.1% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  1 | pp512 |    765.07 ± 3.75 |    835.91 ± 4.26 |  +9.3% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  0 | pp512 | 1206.78 ± 17.16 | 1279.94 ± 16.99 |  +6.1% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  1 | pp512 |  1218.56 ± 8.84 | 1293.59 ± 10.41 |  +6.2% |

Intel A770

| model                          |      size |  params | backend | ngl | fa |  test | t/s (before) |   t/s (after) |  diff |
| ------------------------------ | --------: | ------: | ------- | --: | -: | ----: | -----------: | ------------: | ----: |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  0 | pp512 | 126.93 ± 0.66 | 134.75 ± 0.78 | +6.2% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | Vulkan  |  99 |  1 | pp512 |  82.24 ± 0.30 |  85.61 ± 0.32 | +4.1% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  0 | pp512 | 177.44 ± 2.74 | 184.11 ± 1.31 | +3.8% |
| gpt-oss 20B Q8_0               | 11.27 GiB | 20.91 B | Vulkan  |  99 |  1 | pp512 | 175.73 ± 1.44 | 182.80 ± 2.39 | +4.0% |

@0cc4m merged commit 34bdbbd into ggml-org:master on Aug 26, 2025 (48 checks passed).
jeffbolznv (Collaborator, Author)

> Decent improvements all around. Coopmat1 gpt-oss liked it the most, any ideas why?

I think this might have taken NV coopmat1 from the medium to the large tile size, but I didn't verify.
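One way to picture why less shared memory could change the tile size (the function, tile names, and byte counts below are hypothetical, not the backend's real selection logic or constants): the backend can only use a tile whose shared memory footprint fits the device limit, so shaving bytes off row_ids can let the next-larger tile fit under the 48 KiB limit reported in the logs above.

```python
# Illustrative sketch only: hypothetical tile footprints, not the
# Vulkan backend's real pipeline-selection code.
SHARED_MEM_LIMIT = 49152  # bytes, from the device info above

def pick_tile(tile_footprints):
    """Pick the largest tile whose shared memory use fits the limit.
    tile_footprints: dict of tile name -> shared bytes (hypothetical)."""
    fitting = {name: b for name, b in tile_footprints.items()
               if b <= SHARED_MEM_LIMIT}
    return max(fitting, key=fitting.get) if fitting else None

# With a large per-batch row_ids array, the large tile doesn't fit...
before = {"small": 20000, "medium": 36000, "large": 52000}
# ...but after shrinking row_ids to one tile's worth, it does.
after = {"small": 18000, "medium": 34000, "large": 48000}

print(pick_tile(before))  # medium
print(pick_tile(after))   # large
```

That kind of threshold effect would also explain why one configuration (coopmat1 gpt-oss) benefits far more than the others.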

Minh141120 pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 27, 2025
row_ids only needs to hold the BN rows for the current tile.