Conversation

@am17an (Collaborator) commented Oct 17, 2025

This PR adds support for fusing mmvq and mmvf with an optional gate and GLU. Currently it only supports SWIGLU and no bias, which is by far the most common pattern. Perf gains in TG of 4-9% for quantized models, with smaller gains for floating-point models.

After #16130 and this PR, `ggml_can_fuse` is too primitive to support fusion. What we want is a self-contained DAG with a single exit point, where views are not used elsewhere in the graph. I will open a follow-up PR for that.
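For context, the pattern being fused is the gated FFN epilogue, y = silu(W_gate·x) ⊙ (W_up·x), computed in one kernel instead of two mat-vecs plus a separate GLU op. A minimal host-side sketch of the math (illustrative only; `mmv_swiglu_ref` and its row-major layout are assumptions, not the PR's CUDA kernel):

```cpp
#include <cmath>
#include <cstdint>

// silu(v) = v * sigmoid(v)
static float silu(float v) { return v / (1.0f + std::exp(-v)); }

// Hypothetical reference for the fused mat-vec + SWIGLU:
// y[r] = silu(W_gate[r]·x) * (W_up[r]·x)
// W_gate, W_up: n_rows x n_cols, row-major; x: n_cols; y: n_rows
void mmv_swiglu_ref(const float * W_gate, const float * W_up, const float * x,
                    float * y, int64_t n_rows, int64_t n_cols) {
    for (int64_t r = 0; r < n_rows; ++r) {
        float acc_gate = 0.0f;
        float acc_up   = 0.0f;
        for (int64_t c = 0; c < n_cols; ++c) {
            acc_gate += W_gate[r*n_cols + c] * x[c];
            acc_up   += W_up  [r*n_cols + c] * x[c];
        }
        // Applying the activation and gate in the same pass is what the fusion
        // buys: the unfused path writes both intermediate vectors to global
        // memory and launches a separate GLU kernel to combine them.
        y[r] = silu(acc_gate) * acc_up;
    }
}
```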

Performance on a 4090

| Model | Test | t/s master | t/s cuda_fuse_gate | Speedup |
| --- | --- | --- | --- | --- |
| deepseek2 16B Q4_0 | tg32 | 205.77 | 218.02 | 1.06 |
| deepseek2 16B Q4_0 | tg64 | 202.49 | 215.07 | 1.06 |
| deepseek2 16B Q4_0 | tg128 | 200.51 | 212.70 | 1.06 |
| qwen3 0.6B F16 | tg32 | 311.52 | 325.05 | 1.04 |
| qwen3 0.6B F16 | tg64 | 305.72 | 318.91 | 1.04 |
| qwen3 0.6B F16 | tg128 | 303.35 | 315.34 | 1.04 |
| qwen3 0.6B Q4_0 | tg32 | 416.74 | 453.17 | 1.09 |
| qwen3 0.6B Q4_0 | tg64 | 407.95 | 444.68 | 1.09 |
| qwen3 0.6B Q4_0 | tg128 | 403.34 | 440.55 | 1.09 |
| qwen3 8B F16 | tg32 | 48.15 | 48.42 | 1.01 |
| qwen3 8B F16 | tg64 | 47.93 | 48.17 | 1.01 |
| qwen3 8B F16 | tg128 | 47.83 | 48.07 | 1.00 |
| qwen3moe 30B.A3B Q4_0 | tg32 | 154.49 | 160.81 | 1.04 |
| qwen3moe 30B.A3B Q4_0 | tg64 | 151.12 | 157.17 | 1.04 |
| qwen3moe 30B.A3B Q4_0 | tg128 | 149.57 | 155.33 | 1.04 |
| gemma3 1B F16 | tg32 | 234.06 | 241.52 | 1.03 |
| gemma3 1B F16 | tg64 | 231.35 | 238.68 | 1.03 |
| gemma3 1B F16 | tg128 | 230.10 | 237.11 | 1.03 |

@github-actions bot added the labels `testing` (Everything test related), `Nvidia GPU` (Issues specific to Nvidia GPUs), and `ggml` (changes relating to the ggml tensor library for machine learning) on Oct 17, 2025
@JohannesGaessler (Collaborator) commented:
Without having looked at this PR, consider also fusing the Q, K, and V matrix multiplications into a single, batched operation. It's not going to reduce I/O but it's going to reduce kernel launch overhead and tail effects.
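For illustration, the suggestion amounts to stacking the Q, K, and V projection weights and running one mat-vec over the stacked matrix, then slicing the output. A rough sketch under that assumption (`qkv_stacked_matvec` and the layout are hypothetical, not the ggml/CUDA API):

```cpp
#include <cstdint>

// Hypothetical sketch: Wq (n_q x n_embd), Wk and Wv (n_kv x n_embd) stacked
// row-wise into one (n_q + 2*n_kv) x n_embd matrix, so a single mat-vec
// produces q, k, and v in one launch; the individual projections are just
// offsets into the output. Same total I/O, but fewer kernel launches and
// tail effects than three separate mat-vecs.
void qkv_stacked_matvec(const float * W_qkv, const float * x, float * out,
                        int64_t n_q, int64_t n_kv, int64_t n_embd) {
    const int64_t n_rows = n_q + 2*n_kv;
    for (int64_t r = 0; r < n_rows; ++r) {
        float acc = 0.0f;
        for (int64_t c = 0; c < n_embd; ++c) {
            acc += W_qkv[r*n_embd + c] * x[c];
        }
        out[r] = acc;
    }
    // q = out[0, n_q), k = out[n_q, n_q + n_kv), v = out[n_q + n_kv, n_rows)
}
```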

Review comment on the new test case:

```cpp
}
};

struct test_fused_ffn_gate : public test_case {
```
A collaborator commented:
Would it make sense to extend the test case for matrix multiplication instead?

@am17an (Collaborator, Author) replied:

I think it is fine, because it has a switch for mul_mat_id as well

@am17an (Collaborator, Author) commented Oct 17, 2025

Note that with quantized types and the GEGLU operator, I was getting NMSE as high as 1e-2
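For reference, NMSE normalizes the squared error by the reference's energy, so 1e-2 corresponds to roughly 10% relative RMS deviation. A sketch of the usual definition (the exact formula used by the test harness may differ):

```cpp
#include <cstddef>
#include <vector>

// Normalized mean squared error: sum((out - ref)^2) / sum(ref^2).
double nmse(const std::vector<double> & ref, const std::vector<double> & out) {
    double err = 0.0, norm = 0.0;
    for (std::size_t i = 0; i < ref.size(); ++i) {
        const double d = out[i] - ref[i];
        err  += d*d;
        norm += ref[i]*ref[i];
    }
    return err/norm;
}
```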

@am17an (Collaborator, Author) commented Oct 18, 2025

It looks like the newly added tests for f16 and f32 are failing on CI on a Tesla T4 GPU (CUDA 13) by quite a large amount. I notice that these might be the only tests with m=1 for mul_mat. On a rented T4 (CUDA 12.6) I don't see this problem, so it's either a CUDA version thing or an alignment thing.

EDIT: it was neither, just an ordinary bug: `buf_iw_gate` was not being initialized.

@jammm (Contributor) commented Oct 19, 2025

This improved perf on the ROCm backend as well. The gap to Vulkan is closing with this one!

```
PS D:\jam\llama.cpp\build\bin> .\llama-bench.exe -ngl 99 -mmp 0 -m .\llama-2-7b.Q4_0.gguf -fa 0,1 --device rocm0,vulkan0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 9070 XT, gfx1201 (0x1201), VMM: no, Wave Size: 32
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 9070 XT (AMD proprietary driver) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 64 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
```

Before PR:

model size params backend ngl fa dev mmap test t/s
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 ROCm0 0 pp512 4287.33 ± 11.72
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 ROCm0 0 tg128 120.20 ± 0.38
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 ROCm0 0 pp512 4320.84 ± 19.01
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 ROCm0 0 tg128 121.10 ± 0.39
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 Vulkan0 0 pp512 4007.62 ± 36.73
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 Vulkan0 0 tg128 133.28 ± 1.25
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 Vulkan0 0 pp512 3552.97 ± 4.94
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 Vulkan0 0 tg128 134.93 ± 0.94

After PR:

model size params backend ngl fa dev mmap test t/s
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 ROCm0 0 pp512 4333.61 ± 19.93
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 ROCm0 0 tg128 122.93 ± 0.35
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 ROCm0 0 pp512 4376.61 ± 19.10
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 ROCm0 0 tg128 124.25 ± 0.34
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 Vulkan0 0 pp512 4002.04 ± 41.63
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 0 Vulkan0 0 tg128 135.46 ± 0.13
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 Vulkan0 0 pp512 3552.78 ± 4.35
llama 7B Q4_0 3.56 GiB 6.74 B ROCm,Vulkan 99 1 Vulkan0 0 tg128 135.87 ± 0.30

build: e6c48820 (6797)

BTW, what else can we do to reduce this CPU overhead? I currently have a branch with a few other optimizations applied from other PRs (https://github.com/ggml-org/llama.cpp/compare/master...jammm:llama.cpp:jam/gfx1201?expand=1), though this PR gives the most benefit.

@am17an (Collaborator, Author) commented Oct 19, 2025

@jammm it depends on your model; kernel fusion currently helps more for MoE models. BTW, sooner or later Vulkan will also adopt a similar thing, so it's a shifting goalpost. For AMD specifically, @JohannesGaessler recently made some changes in FA, so be sure to include those as well. Another PR you can try, related to CUDA graphs: #16548

@jammm (Contributor) commented Oct 19, 2025

> @jammm it depends on your model; kernel fusion currently helps more for MoE models. BTW, sooner or later Vulkan will also adopt a similar thing, so it's a shifting goalpost. For AMD specifically, @JohannesGaessler recently made some changes in FA, so be sure to include those as well. Another PR you can try, related to CUDA graphs: #16548

Thanks! Will take a look at those. You're right re. the shifting goalposts.
