Conversation

jeffbolznv
Collaborator

Track a list of nodes that need synchronization, and only sync if the new node depends on them (or overwrites them). This allows some overlap, which can improve performance, and it centralizes a big chunk of the synchronization logic.

The remaining synchronization logic involves writes to memory other than the nodes, e.g. for dequantization or split_k. Each of these allocations has a bool indicating whether it is in use and needs to be synced. This flag should be checked before the allocation is written to, and set to true after its contents are done being consumed.
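
To make the mechanism concrete, here is a minimal CPU-side sketch of the range-based tracking described above. All names (`Range`, `Node`, `SyncTracker`) are hypothetical illustrations, not the actual ggml-vulkan types, and the real code records a Vulkan pipeline barrier where the comment indicates:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct Range { uint64_t offset, size; };              // region of a device buffer
struct Node  { Range dst; std::vector<Range> srcs; }; // one graph node's ranges

static bool overlaps(const Range &a, const Range &b) {
    return a.offset < b.offset + b.size && b.offset < a.offset + a.size;
}

struct SyncTracker {
    std::vector<Range> unsynced;                      // writes since the last barrier

    // Returns true if a barrier must be recorded before executing `n`.
    bool needs_sync(const Node &n) const {
        for (const Range &w : unsynced) {
            if (overlaps(w, n.dst)) return true;      // write-after-write hazard
            for (const Range &s : n.srcs)
                if (overlaps(w, s)) return true;      // read-after-write hazard
        }
        return false;
    }

    void process(const Node &n) {
        if (needs_sync(n)) {
            // record a pipeline barrier here in the real backend
            unsynced.clear();                         // all prior writes now visible
        }
        unsynced.push_back(n.dst);                    // this node's output is unsynced
    }
};

int main() {
    SyncTracker t;
    Node nodes[] = {
        {{0, 256}, {}},             // A: writes [0,256)
        {{256, 256}, {}},           // B: independent of A -> no barrier, can overlap
        {{512, 256}, {{0, 256}}},   // C: reads A's output -> barrier first
    };
    for (const Node &n : nodes) {
        std::printf("sync before node: %s\n", t.needs_sync(n) ? "yes" : "no");
        t.process(n);
    }
}
```

The auxiliary allocations (dequant buffers, split_k scratch, etc.) would sit alongside this with the simple bool protocol from the description: check the flag before writing to the allocation, set it once its contents have been consumed.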

5090 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 128 -p 0 -r 20 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf -m c:\models\Mistral-22B-v0.2-Q4_K_M.gguf -m c:\models\nvidia_Llama-3_3-Nemotron-Super-49B-v1_5-Q4_K_S.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        197.93 ± 1.24 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        172.73 ± 2.35 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        110.35 ± 1.28 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |       601.99 ± 13.87 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        592.46 ± 6.37 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        303.06 ± 1.98 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        213.76 ± 3.46 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        198.53 ± 2.93 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        263.24 ± 3.61 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        230.77 ± 5.15 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        301.65 ± 3.60 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        221.09 ± 1.52 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        271.62 ± 1.58 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         80.44 ± 0.23 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         44.49 ± 0.20 |

5090 after

ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        201.15 ± 1.10 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        174.76 ± 4.70 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        111.77 ± 1.92 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        634.98 ± 8.88 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        613.19 ± 5.54 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        316.52 ± 2.02 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        226.58 ± 3.22 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        200.38 ± 3.24 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        280.54 ± 4.22 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        238.21 ± 4.31 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        310.83 ± 9.35 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        228.22 ± 1.34 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        280.84 ± 2.44 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         83.38 ± 0.39 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         44.89 ± 0.44 |

4070 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 128 -p 0 -r 10 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf -m c:\models\Mistral-22B-v0.2-Q4_K_M.gguf -m c:\models\nvidia_Llama-3_3-Nemotron-Super-49B-v1_5-Q4_K_S.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         87.45 ± 0.13 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         68.77 ± 0.11 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         47.64 ± 0.04 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        399.82 ± 2.91 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        368.13 ± 1.49 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        156.47 ± 0.27 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        145.02 ± 4.27 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        101.39 ± 1.72 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        169.05 ± 4.62 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        123.79 ± 1.34 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        150.62 ± 0.27 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        102.17 ± 0.23 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        117.61 ± 0.37 |

4070 after

ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         88.23 ± 0.18 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         68.99 ± 0.52 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         48.26 ± 0.06 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        405.26 ± 2.08 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        369.52 ± 6.98 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        158.72 ± 0.26 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        150.36 ± 5.59 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        102.80 ± 1.29 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        173.97 ± 6.52 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        124.76 ± 3.59 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        152.67 ± 0.48 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        103.38 ± 0.43 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        119.54 ± 0.21 |

jeffbolznv requested a review from 0cc4m as a code owner August 21, 2025 20:14
github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels Aug 21, 2025
@slaren
Member

slaren commented Aug 21, 2025

I don't know if this is a problem with this implementation specifically, but a common issue when doing this is that you have to consider that not every dependency is represented in the graph. A classic example was the ggml_cpy operation to update the KV cache. The attention computation depends on this operation, but this dependency is not represented directly in the graph. I believe ggml_set_rows has the same issue.

@jeffbolznv
Collaborator Author

I'm not specifically looking at the graph; I'm looking at the ranges of memory used by each tensor (destination and sources) as they're processed in order, and that should include things like set_rows and cpy. I think it all holds together, and I did some testing with real models to check that they generate identical results before and after.

BTW, while working on this I saw some cases where there are "false" dependencies due to memory reuse, i.e. things that wouldn't otherwise need synchronization do need it because a result from a couple of nodes ago is being overwritten. So there's potential for even more performance if the allocator took this into account.
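
As a tiny illustration of such a false dependency (hypothetical offsets, using the same overlap test as the sketch above): node C consumes neither A nor B, but the allocator handed it the region that held A's output, so the range check still has to treat it as a hazard.

```cpp
#include <cstdint>
#include <cstdio>

struct Range { uint64_t offset, size; };
static bool overlaps(Range a, Range b) {
    return a.offset < b.offset + b.size && b.offset < a.offset + a.size;
}

int main() {
    Range a_dst {0, 256};   // node A wrote [0,256)
    Range c_dst {0, 256};   // allocator reused A's region for node C's output

    // C is logically independent of A, but because it overwrites A's (still
    // unsynced) result, the range check must force a barrier anyway:
    std::printf("false dependency: %s\n", overlaps(c_dst, a_dst) ? "yes" : "no");
}
```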

@0cc4m
Collaborator

0cc4m commented Aug 23, 2025

ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2

| model         |     size | params | backend | ngl | fa |  test |  t/s (before) |   t/s (after) |  diff |
| ------------- | -------: | -----: | ------- | --: | -: | ----: | ------------: | ------------: | ----: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  0 | tg128 | 140.84 ± 1.14 | 145.06 ± 0.41 | +3.0% |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  1 | tg128 | 142.41 ± 0.14 | 146.29 ± 0.23 | +2.7% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 127.63 ± 0.83 | 131.63 ± 0.35 | +3.1% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 129.84 ± 0.30 | 133.99 ± 0.25 | +3.2% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 118.34 ± 2.92 | 122.61 ± 0.36 | +3.6% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 122.11 ± 0.18 | 125.35 ± 0.17 | +2.7% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 |  87.12 ± 1.86 |  88.69 ± 2.00 | +1.8% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 |  89.24 ± 0.22 |  90.72 ± 0.12 | +1.7% |

ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: none

| model         |     size | params | backend | ngl | fa |  test | t/s (before) |  t/s (after) |  diff |
| ------------- | -------: | -----: | ------- | --: | -: | ----: | -----------: | -----------: | ----: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  0 | tg128 | 76.59 ± 0.12 | 78.44 ± 0.16 | +2.4% |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  1 | tg128 | 77.90 ± 0.21 | 79.54 ± 0.26 | +2.1% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 70.15 ± 0.20 | 71.13 ± 0.19 | +1.4% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 66.92 ± 0.26 | 67.58 ± 0.09 | +1.0% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 61.11 ± 0.10 | 63.11 ± 0.10 | +3.3% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 58.65 ± 0.09 | 60.48 ± 0.06 | +3.1% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 58.85 ± 0.10 | 58.99 ± 0.14 | +0.2% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 56.59 ± 0.05 | 56.31 ± 0.26 | -0.5% |

ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none

| model         |     size | params | backend | ngl | fa |  test | t/s (before) |  t/s (after) |  diff |
| ------------- | -------: | -----: | ------- | --: | -: | ----: | -----------: | -----------: | ----: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  0 | tg128 | 43.35 ± 0.74 | 46.85 ± 0.96 | +8.1% |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan  |  99 |  1 | tg128 | 44.48 ± 0.13 | 48.14 ± 0.02 | +8.2% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 37.76 ± 0.67 | 38.26 ± 0.57 | +1.3% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 28.20 ± 0.06 | 28.47 ± 0.02 | +1.0% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 33.38 ± 0.09 | 34.39 ± 0.01 | +3.0% |
| llama 8B Q4_1 | 4.77 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 | 25.70 ± 0.02 | 26.44 ± 0.02 | +2.9% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  0 | tg128 | 10.86 ± 0.02 | 11.00 ± 0.03 | +1.3% |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan  |  99 |  1 | tg128 |  9.93 ± 0.02 | 10.05 ± 0.00 | +1.2% |

I don't see any issues with the model outputs after this change either.

0cc4m merged commit 289bf41 into ggml-org:master Aug 23, 2025
48 checks passed
jeffbolznv added a commit to jeffbolznv/llama.cpp that referenced this pull request Aug 23, 2025
jeffbolznv added a commit that referenced this pull request Aug 23, 2025
…le SMs (#15281)

* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with
different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make
the add shader atomically accumulate the values^2 into memory. Then the
rms_norm shader can just load that sum. This allows the rms_norm to be
parallelized across multiple workgroups, it just becomes a simple per-element
multiply.

The fusion optimization is currently only applied when the rms_norm is on a
single vector. This previously always ran on a single SM. It could apply more
broadly, but when there are other dimensions the work can already spread across
SMs, and there would be some complexity to tracking multiple atomic sums.

* Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic add, to make it deterministic. The rms_norm
shader fetches a subgroup's worth in parallel and uses subgroupAdd to
add them up.

* complete rebase against fused adds - multi_add shader can also compute partial sums

* fix validation errors

* disable add_rms_fusion for Intel due to possible driver bug

* resolve against #15489, sync after clearing partial sums
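
A minimal CPU-side sketch of the deterministic partial-sums scheme from the commit message above (hypothetical sizes and names, not the actual shaders): each "workgroup" of the fused add writes its sum of squares to its own slot, and the rms_norm step reduces the slots afterwards, so the result does not depend on scheduling the way a float atomicAdd would. On the GPU the commit reduces the slots with subgroupAdd; here a plain loop stands in for it.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const size_t n = 1024, wg_size = 64;
    std::vector<float> a(n, 0.5f), b(n, 0.25f), out(n);
    const size_t num_wg = (n + wg_size - 1) / wg_size;
    std::vector<float> partial(num_wg, 0.0f);     // one slot per workgroup

    // fused add: compute a+b and accumulate (a+b)^2 into this group's slot
    for (size_t wg = 0; wg < num_wg; ++wg) {
        for (size_t i = wg * wg_size; i < std::min(n, (wg + 1) * wg_size); ++i) {
            out[i] = a[i] + b[i];
            partial[wg] += out[i] * out[i];
        }
    }

    // rms_norm: reduce the partial sums (the fixed set of slots makes this
    // deterministic), then apply the simple per-element multiply
    float sum = 0.0f;
    for (float p : partial) sum += p;
    const float eps = 1e-6f, scale = 1.0f / std::sqrt(sum / n + eps);
    for (size_t i = 0; i < n; ++i) out[i] *= scale;

    std::printf("rms scale = %f\n", scale);
}
```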
FlorianZimmer pushed a commit to FlorianZimmer/llama.cpp that referenced this pull request Aug 25, 2025
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 25, 2025
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 25, 2025
taronaeo pushed a commit that referenced this pull request Sep 8, 2025
* vulkan: sort graph to allow more parallel execution

Add a backend proc to allow the backend to modify the graph. The
vulkan implementation looks at which nodes depend on each other
and greedily reorders them to group together nodes that don't
depend on each other. It only reorders the nodes, doesn't change
the contents of any of them.

With #15489, this reduces the number of synchronizations needed.

* call optimize_graph per-split
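
A minimal sketch of the grouping idea behind that reordering pass, written as a layered topological sort over hypothetical node structures (the commit describes the real pass only as greedy reordering; this is not the ggml-vulkan code). Nodes whose dependencies are already scheduled are pulled together, so each layer can execute without barriers between its members:

```cpp
#include <cstdio>
#include <vector>

struct Node { int id; std::vector<int> deps; };   // deps = node ids this node reads

int main() {
    // 0 and 1 are independent; 2 depends on 0; 3 depends on 1
    std::vector<Node> nodes = {{0, {}}, {2, {0}}, {1, {}}, {3, {1}}};

    std::vector<Node> order;
    std::vector<bool> scheduled(nodes.size(), false); // by position in input
    std::vector<bool> done(nodes.size(), false);      // by node id
    while (order.size() < nodes.size()) {
        // collect every node whose deps are all scheduled; nodes picked in the
        // same pass are independent of each other, so no barrier between them
        std::vector<size_t> ready;
        for (size_t i = 0; i < nodes.size(); ++i) {
            if (scheduled[i]) continue;
            bool ok = true;
            for (int d : nodes[i].deps) ok = ok && done[d];
            if (ok) ready.push_back(i);
        }
        if (ready.empty()) break;                     // guard (graph assumed acyclic)
        for (size_t i : ready) {
            scheduled[i] = true;
            order.push_back(nodes[i]);
        }
        for (size_t i : ready) done[nodes[i].id] = true;
    }
    for (const Node &n : order) std::printf("%d ", n.id); // prints: 0 1 2 3
    std::printf("\n");
}
```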