
Conversation

@jeffbolznv (Collaborator)

q4_k and q5_k had a lot of redundant global loads: the same 16B of scale information was repeatedly loaded and decoded on each loop iteration. This change restructures the loops to explicitly iterate over whole blocks in the outer loop (with an unrolled inner loop) and to copy/decode the scale data into shared memory once at the start of each outer loop iteration. The copy is pipelined so the scale load from global memory is relatively cheap.
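Roughly, the new structure looks like the following minimal sketch. This is not the actual mul_mm_cm2.comp code: the buffer layout, `NUM_BLOCKS`, `INNER`, and the dequant math are all placeholders; the point is only that the 16B scale record is read from global memory once per block instead of once per inner-loop step.

```glsl
#version 450
#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 64) in;

// Placeholder layout: 4 uints (16B) of packed scale data per block.
layout(binding = 0) readonly buffer BlockScales { uint blk_scales[]; };
layout(binding = 1) writeonly buffer Result { float result[]; };

const uint NUM_BLOCKS = 256; // illustrative
const uint INNER = 8;        // illustrative

// The 16B scale record for the current block, decoded once per outer
// iteration instead of once per inner iteration.
shared uvec4 sh_scales;

void main() {
    float acc = 0.0;
    // Outer loop: one iteration per whole block.
    for (uint b = 0; b < NUM_BLOCKS; ++b) {
        // Copy/decode the scales into shared memory once. The real shader
        // additionally pipelines this copy so the global load overlaps
        // with compute.
        if (gl_LocalInvocationID.x == 0) {
            sh_scales = uvec4(blk_scales[4u*b+0u], blk_scales[4u*b+1u],
                              blk_scales[4u*b+2u], blk_scales[4u*b+3u]);
        }
        barrier();
        // Unrolled inner loop: reuse the shared copy instead of re-loading
        // and re-decoding the same 16B from global memory each iteration.
        [[unroll]] for (uint i = 0; i < INNER; ++i) {
            acc += float(sh_scales[i & 3u]); // stand-in for the dequant math
        }
        barrier(); // keep sh_scales stable until everyone has read it
    }
    result[gl_GlobalInvocationID.x] = acc;
}
```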

This improves q4_k/q5_k model prompt processing performance by around 5-7%. I briefly tried applying this to q6_k and q4_0, and it didn't help for q6_k and hurt for q4_0.

The big "else" path in mul_mm_cm2.comp that had all the clamped/unclamped variants isn't used as often as it originally was (e.g. due to the padded_N change), so I trimmed it down to offset some of the new complexity of the semi-manual loop unrolling.

Perf measurements on RTX 4070 and 3070:

RTX 4070:

before:
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      834 runs -  1201.38 us/run -  60.13 GFLOP/run -  50.05 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      760 runs -  1318.26 us/run -  60.13 GFLOP/run -  45.61 TFLOPS
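For reference, the TFLOPS column reported by test-backend-ops is just the per-run work divided by the per-run time, e.g. for the q4_K row above:

$$\frac{60.13\ \text{GFLOP}}{1201.38\ \mu\text{s}} \approx 50.05\ \text{TFLOPS}$$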

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp128 |     4108.08 ± 855.66 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp256 |     5450.02 ± 121.23 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp512 |      5867.83 ± 71.75 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp128 |      2819.20 ± 37.09 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp256 |      3311.65 ± 15.69 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp512 |       3380.34 ± 7.45 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2394.74 ± 0.00 |

after:
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      924 runs -  1084.26 us/run -  60.13 GFLOP/run -  55.46 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      856 runs -  1169.77 us/run -  60.13 GFLOP/run -  51.40 TFLOPS

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp128 |      4739.46 ± 36.16 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp256 |     5708.61 ± 194.14 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp512 |      6243.01 ± 58.32 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp128 |      3005.27 ± 34.27 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp256 |      3476.78 ± 21.73 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp512 |      3626.96 ± 29.06 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2528.98 ± 0.00 |


RTX 3070:

before:
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      616 runs -  1625.87 us/run -  60.13 GFLOP/run -  36.98 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      562 runs -  1780.18 us/run -  60.13 GFLOP/run -  33.78 TFLOPS

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp128 |      3189.86 ± 55.47 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp256 |      3778.46 ± 20.94 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp512 |       3986.65 ± 4.41 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp128 |      2026.14 ± 32.03 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp256 |       2354.75 ± 4.51 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp512 |      2407.21 ± 17.41 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       1746.93 ± 0.00 |

after:
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      682 runs -  1467.96 us/run -  60.13 GFLOP/run -  40.96 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      608 runs -  1647.96 us/run -  60.13 GFLOP/run -  36.49 TFLOPS

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp128 |      3353.68 ± 63.07 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp256 |      3983.36 ± 34.95 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |         pp512 |      4211.02 ± 17.09 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 128,256,512 -n 0 -fa 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp128 |      2218.93 ± 17.03 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp256 |      2484.39 ± 64.54 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |         pp512 |      2607.34 ± 19.53 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       1861.15 ± 0.00 |

@jeffbolznv requested a review from 0cc4m on April 8, 2025
@github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on April 8, 2025
@characharm (Contributor)

ggml_vulkan: 0 = Intel(R) Arc(TM) A770 Graphics (Intel Corporation) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none

| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp128 |        138.26 ± 0.21 |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp256 |        147.72 ± 2.03 |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp512 |        165.60 ± 1.17 |

build: 8918306 (5075)

ggml_vulkan: 0 = Intel(R) Arc(TM) A770 Graphics (Intel Corporation) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none

| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp128 |        136.43 ± 1.22 |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp256 |        146.89 ± 1.16 |
| qwen2 14B Q5_K - Medium        |   9.78 GiB |    14.77 B | Vulkan     |  99 |         pp512 |        165.55 ± 0.38 |

build: dab1f02 (5062)

@jeffbolznv (Collaborator, Author)

> ggml_vulkan: 0 = Intel(R) Arc(TM) A770 Graphics (Intel Corporation) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none

This change only affects the NV_coopmat2 shader, so won't affect results on Intel.

@characharm (Contributor)

> This change only affects the NV_coopmat2 shader, so won't affect results on Intel.

Oh, okay — and by the way, thanks for your work. :)

@0cc4m (Collaborator) left a comment:

LGTM

@0cc4m 0cc4m merged commit 0090950 into ggml-org:master Apr 9, 2025
51 checks passed
colout pushed a commit to colout/llama.cpp that referenced this pull request Apr 21, 2025
…gml-org#12833)
timwu pushed a commit to timwu/llama.cpp that referenced this pull request May 5, 2025
…gml-org#12833)