@JohannesGaessler (Collaborator) commented Aug 23, 2025

This PR moves the helper code for the MoE MMQ kernel from host code to device code, which eliminates the need for synchronization. It also sets a tighter bound on the maximum number of columns that the kernel selection logic optimizes for, so there is less waste at small batch sizes.
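For context on what "helper code in device code" means here, a minimal and purely illustrative sketch of the general pattern (the kernel name and data layout are assumptions for this sketch, not the actual MMQ helper): once the bookkeeping runs on the device, per-expert token counts can be produced by a kernel and consumed by the next kernel directly, so the host never has to copy the expert ids back and synchronize on them.

```cpp
#include <cstdint>

// Illustrative only: build a per-expert token histogram on the device so the
// host does not need to read the expert ids and synchronize before launch.
__global__ void count_tokens_per_expert(
        const int32_t * ids,    // [n_ids] expert id for each (token, slot) pair
        int32_t       * counts, // [n_expert] per-expert token counts, pre-zeroed
        const int       n_ids) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n_ids) {
        atomicAdd(&counts[ids[i]], 1); // histogram stays in device memory
    }
}
```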

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-mmq-moe-device-4 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| P40 | deepseek2 16B Q4_0 | 1 | pp512 | 57.74 | 58.18 | 1.01 |
| P40 | deepseek2 16B Q4_0 | 2 | pp512 | 78.50 | 89.00 | 1.13 |
| P40 | deepseek2 16B Q4_0 | 4 | pp512 | 105.67 | 128.83 | 1.22 |
| P40 | deepseek2 16B Q4_0 | 8 | pp512 | 114.87 | 178.14 | 1.55 |
| P40 | deepseek2 16B Q4_0 | 16 | pp512 | 165.17 | 240.73 | 1.46 |
| P40 | deepseek2 16B Q4_0 | 32 | pp512 | 207.04 | 246.58 | 1.19 |
| P40 | deepseek2 16B Q4_0 | 64 | pp512 | 257.65 | 276.26 | 1.07 |
| P40 | deepseek2 16B Q4_0 | 128 | pp512 | 311.62 | 327.51 | 1.05 |
| P40 | deepseek2 16B Q4_0 | 256 | pp512 | 357.65 | 379.51 | 1.06 |
| P40 | deepseek2 16B Q4_0 | 512 | pp512 | 427.84 | 451.53 | 1.06 |
| P40 | gpt-oss 20B MXFP4 MoE | 1 | pp512 | 63.62 | 63.71 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 2 | pp512 | 71.50 | 73.77 | 1.03 |
| P40 | gpt-oss 20B MXFP4 MoE | 4 | pp512 | 100.22 | 114.05 | 1.14 |
| P40 | gpt-oss 20B MXFP4 MoE | 8 | pp512 | 136.14 | 173.58 | 1.28 |
| P40 | gpt-oss 20B MXFP4 MoE | 16 | pp512 | 187.35 | 267.64 | 1.43 |
| P40 | gpt-oss 20B MXFP4 MoE | 32 | pp512 | 306.10 | 385.23 | 1.26 |
| P40 | gpt-oss 20B MXFP4 MoE | 64 | pp512 | 485.95 | 498.00 | 1.02 |
| P40 | gpt-oss 20B MXFP4 MoE | 128 | pp512 | 708.98 | 735.89 | 1.04 |
| P40 | gpt-oss 20B MXFP4 MoE | 256 | pp512 | 970.32 | 1023.21 | 1.05 |
| P40 | gpt-oss 20B MXFP4 MoE | 512 | pp512 | 1163.90 | 1243.66 | 1.07 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 1 | pp512 | 38.02 | 38.00 | 1.00 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 2 | pp512 | 30.44 | 51.45 | 1.69 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 4 | pp512 | 31.11 | 70.76 | 2.27 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 8 | pp512 | 30.99 | 90.81 | 2.93 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 16 | pp512 | 45.84 | 128.51 | 2.80 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 32 | pp512 | 66.87 | 126.79 | 1.90 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 64 | pp512 | 98.59 | 139.29 | 1.41 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 128 | pp512 | 144.56 | 205.02 | 1.42 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 256 | pp512 | 205.54 | 281.61 | 1.37 |
| 3x P40 | glm4moe 106B.A12B Q4_0 | 512 | pp512 | 259.91 | 323.39 | 1.24 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 1 | pp512 | 58.20 | 58.26 | 1.00 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 2 | pp512 | 40.39 | 59.87 | 1.48 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 4 | pp512 | 55.51 | 89.71 | 1.62 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 8 | pp512 | 72.68 | 128.98 | 1.77 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 16 | pp512 | 88.47 | 183.06 | 2.07 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 32 | pp512 | 132.16 | 232.25 | 1.76 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 64 | pp512 | 186.98 | 260.87 | 1.40 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 128 | pp512 | 255.21 | 347.71 | 1.36 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 256 | pp512 | 349.38 | 455.38 | 1.30 |
| 3x P40 | gpt-oss 120B MXFP4 MoE | 512 | pp512 | 440.57 | 514.80 | 1.17 |
| RTX 3090 | deepseek2 16B Q4_0 | 1 | pp512 | 180.89 | 180.53 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | 2 | pp512 | 158.03 | 180.94 | 1.14 |
| RTX 3090 | deepseek2 16B Q4_0 | 4 | pp512 | 241.87 | 291.18 | 1.20 |
| RTX 3090 | deepseek2 16B Q4_0 | 8 | pp512 | 374.09 | 475.54 | 1.27 |
| RTX 3090 | deepseek2 16B Q4_0 | 16 | pp512 | 499.68 | 784.29 | 1.57 |
| RTX 3090 | deepseek2 16B Q4_0 | 32 | pp512 | 854.65 | 1209.58 | 1.42 |
| RTX 3090 | deepseek2 16B Q4_0 | 64 | pp512 | 1268.85 | 1802.95 | 1.42 |
| RTX 3090 | deepseek2 16B Q4_0 | 128 | pp512 | 2103.93 | 2253.15 | 1.07 |
| RTX 3090 | deepseek2 16B Q4_0 | 256 | pp512 | 3280.19 | 3628.64 | 1.11 |
| RTX 3090 | deepseek2 16B Q4_0 | 512 | pp512 | 4260.97 | 4848.13 | 1.14 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 1 | pp512 | 176.20 | 174.48 | 0.99 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 2 | pp512 | 151.82 | 161.94 | 1.07 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 4 | pp512 | 246.36 | 270.08 | 1.10 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 8 | pp512 | 363.81 | 438.64 | 1.21 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 16 | pp512 | 523.40 | 717.60 | 1.37 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 32 | pp512 | 676.41 | 1102.68 | 1.63 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 64 | pp512 | 1130.60 | 1548.65 | 1.37 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 128 | pp512 | 1925.10 | 1941.23 | 1.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 256 | pp512 | 2877.88 | 2951.21 | 1.03 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 512 | pp512 | 3726.23 | 3894.64 | 1.05 |
| RTX 4090 | deepseek2 16B Q4_0 | 1 | pp512 | 242.11 | 247.38 | 1.02 |
| RTX 4090 | deepseek2 16B Q4_0 | 2 | pp512 | 219.70 | 270.47 | 1.23 |
| RTX 4090 | deepseek2 16B Q4_0 | 4 | pp512 | 353.59 | 442.09 | 1.25 |
| RTX 4090 | deepseek2 16B Q4_0 | 8 | pp512 | 581.05 | 726.42 | 1.25 |
| RTX 4090 | deepseek2 16B Q4_0 | 16 | pp512 | 850.81 | 1225.74 | 1.44 |
| RTX 4090 | deepseek2 16B Q4_0 | 32 | pp512 | 1500.40 | 2055.06 | 1.37 |
| RTX 4090 | deepseek2 16B Q4_0 | 64 | pp512 | 2324.07 | 3441.40 | 1.48 |
| RTX 4090 | deepseek2 16B Q4_0 | 128 | pp512 | 3942.70 | 4842.43 | 1.23 |
| RTX 4090 | deepseek2 16B Q4_0 | 256 | pp512 | 6207.43 | 8172.86 | 1.32 |
| RTX 4090 | deepseek2 16B Q4_0 | 512 | pp512 | 8440.21 | 11711.92 | 1.39 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 1 | pp512 | 270.94 | 275.25 | 1.02 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 2 | pp512 | 215.82 | 254.63 | 1.18 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 4 | pp512 | 370.00 | 439.38 | 1.19 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 8 | pp512 | 583.13 | 738.66 | 1.27 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 16 | pp512 | 885.30 | 1228.62 | 1.39 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 32 | pp512 | 1252.03 | 2024.84 | 1.62 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 64 | pp512 | 2158.26 | 3026.16 | 1.40 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 128 | pp512 | 3836.36 | 4143.91 | 1.08 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 256 | pp512 | 5829.64 | 6404.23 | 1.10 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 512 | pp512 | 7845.31 | 8871.90 | 1.13 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 1 | pp512 | 132.35 | 132.37 | 1.00 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 2 | pp512 | 79.11 | 136.01 | 1.72 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 4 | pp512 | 116.05 | 214.34 | 1.85 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 8 | pp512 | 164.96 | 322.32 | 1.95 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 16 | pp512 | 199.79 | 473.61 | 2.37 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 32 | pp512 | 321.17 | 721.42 | 2.25 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 64 | pp512 | 533.66 | 1114.73 | 2.09 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 128 | pp512 | 875.94 | 1363.58 | 1.56 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 256 | pp512 | 1351.89 | 2106.57 | 1.56 |
| 3x RTX 4090 | glm4moe 106B.A12B Q4_0 | 512 | pp512 | 1799.77 | 2679.69 | 1.49 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 1 | pp512 | 254.17 | 254.57 | 1.00 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 2 | pp512 | 124.66 | 215.40 | 1.73 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 4 | pp512 | 202.60 | 354.83 | 1.75 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 8 | pp512 | 297.25 | 557.90 | 1.88 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 16 | pp512 | 415.05 | 870.82 | 2.10 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 32 | pp512 | 529.60 | 1272.02 | 2.40 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 64 | pp512 | 836.27 | 1657.87 | 1.98 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 128 | pp512 | 1304.74 | 1914.18 | 1.47 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 256 | pp512 | 1975.84 | 2793.96 | 1.41 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 512 | pp512 | 2739.94 | 3583.90 | 1.31 |
| RX 6800 | deepseek2 16B Q4_0 | 1 | pp512 | 52.35 | 52.61 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | 2 | pp512 | 56.66 | 73.42 | 1.30 |
| RX 6800 | deepseek2 16B Q4_0 | 4 | pp512 | 84.59 | 114.84 | 1.36 |
| RX 6800 | deepseek2 16B Q4_0 | 8 | pp512 | 93.89 | 164.02 | 1.75 |
| RX 6800 | deepseek2 16B Q4_0 | 16 | pp512 | 133.09 | 191.05 | 1.44 |
| RX 6800 | deepseek2 16B Q4_0 | 32 | pp512 | 169.42 | 226.48 | 1.34 |
| RX 6800 | deepseek2 16B Q4_0 | 64 | pp512 | 227.17 | 237.71 | 1.05 |
| RX 6800 | deepseek2 16B Q4_0 | 128 | pp512 | 287.05 | 280.27 | 0.98 |
| RX 6800 | deepseek2 16B Q4_0 | 256 | pp512 | 346.76 | 364.23 | 1.05 |
| RX 6800 | deepseek2 16B Q4_0 | 512 | pp512 | 432.14 | 451.29 | 1.04 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 1 | pp512 | 75.12 | 75.12 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 2 | pp512 | 57.74 | 70.95 | 1.23 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 4 | pp512 | 83.53 | 109.49 | 1.31 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 8 | pp512 | 108.09 | 157.44 | 1.46 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 16 | pp512 | 140.85 | 232.70 | 1.65 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 32 | pp512 | 231.43 | 312.12 | 1.35 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 64 | pp512 | 366.12 | 374.70 | 1.02 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 128 | pp512 | 531.25 | 546.71 | 1.03 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 256 | pp512 | 689.86 | 715.34 | 1.04 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 512 | pp512 | 757.45 | 785.62 | 1.04 |

@IMbackK if you could also check performance, that would be appreciated.

The github-actions bot added the labels "Nvidia GPU" (issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) on Aug 23, 2025.
@JohannesGaessler (Collaborator, Author) commented:

Looking at the code I realized that it's possible to reduce the number of tiles/CUDA blocks that need to be considered for MoE. These tiles/CUDA blocks would be skipped regardless but you can squeeze out a few more % if you never need to do this check in the first place. (I updated the table in the OP.)
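A rough way to picture the effect (an illustrative host-side sketch; the tile height and row counts are made-up numbers, not taken from the kernel): with a loose upper bound the extra blocks are launched and then skip themselves, while a tighter bound means they are never launched at all.

```cpp
#include <cstdio>

// Ceiling division: number of tiles/CUDA blocks needed to cover n_rows.
static int n_tiles(int n_rows, int tile_height) {
    return (n_rows + tile_height - 1) / tile_height;
}

int main() {
    const int tile_height        = 64;  // assumed MMQ tile height
    const int n_rows_upper_bound = 512; // loose bound: every token could hit this expert
    const int n_rows_actual      = 96;  // rows actually routed to this expert
    printf("blocks, loose bound: %d\n", n_tiles(n_rows_upper_bound, tile_height));
    printf("blocks, tight bound: %d\n", n_tiles(n_rows_actual,      tile_height));
    return 0;
}
```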

@IMbackK (Collaborator) left a comment:

On CDNA, this PR causes a crash or a (random) failure of various tests, such as `./bin/test-backend-ops -p type_a=q8_0,type_b=f32,n_mats=4,n_used=2,b=0,m=512,n=32,k=256,o=1`, as well as crashes during model inference.

```diff
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+#ifndef GGML_USE_HIP
```
@IMbackK (Collaborator):

Why restrict this to CUDA? Might as well use `__all` on HIP and check `ggml_cuda_get_physical_warp_size()`; even if we never use this with `width == 64`, that should still help RDNA.
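A sketch of what the suggestion could look like (assumptions: the existing `ggml_cuda_get_physical_warp_size()` helper, ggml's HIP wrapper for `__shfl_xor_sync`, and the HIP `__all()` warp vote intrinsic; this is not the code that was actually merged):

```cpp
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
    // On AMD the __all() warp vote operates on the physical wavefront, so it
    // can be used whenever the requested width matches the physical warp size.
    if (width == ggml_cuda_get_physical_warp_size()) {
        return __all(x);
    }
#else
    if (width == WARP_SIZE) {
        return __all_sync(0xffffffff, x);
    }
#endif // GGML_USE_HIP
    // Fallback for partial widths: shuffle-based reduction.
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}
```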

@JohannesGaessler (Collaborator, Author):

I just wasn't aware of the `__all` instruction.


```cpp
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_any(int x) {
#ifndef GGML_USE_HIP
```
@IMbackK (Collaborator):

Same as above.

@JohannesGaessler (Collaborator, Author) commented:

I think the problem was that I was misusing the warp reduce functions. I forgot that the default width is `WARP_SIZE`, not `ggml_cuda_get_physical_warp_size()`.

@IMbackK (Collaborator) commented Aug 23, 2025:

Yup, `warp_reduce_any<warp_size>` is needed.
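For illustration, the calling pattern being discussed (a sketch only; the kernel body is made up, and `warp_reduce_any` is the helper quoted above):

```cpp
__global__ void example_kernel(const int * flags, int * out) {
    // WARP_SIZE is 32, so the default template width only covers half of a
    // 64-wide CDNA wavefront; pass the physical warp size explicitly.
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    const int x = flags[blockIdx.x*blockDim.x + threadIdx.x];

    // const int any = warp_reduce_any(x);          // wrong on CDNA (width defaults to 32)
    const int any = warp_reduce_any<warp_size>(x);  // reduces over the full physical warp

    if (threadIdx.x == 0) {
        out[blockIdx.x] = any;
    }
}
```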

@IMbackK (Collaborator) commented Aug 23, 2025:

Should be OK now I think, let me check.

@IMbackK (Collaborator) commented Aug 23, 2025:

Correctness is fine and performance looks fine too.

CDNA:

| Model | Microbatch size | Test | t/s master | t/s cuda-mmq-moe-device-4 | Speedup |
| --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 1 | pp512 | 83.03 | 83.30 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 4 | pp512 | 100.06 | 107.98 | 1.08 |
| gpt-oss 20B MXFP4 MoE | 8 | pp512 | 151.54 | 181.60 | 1.20 |
| gpt-oss 20B MXFP4 MoE | 64 | pp512 | 435.87 | 591.27 | 1.36 |
| gpt-oss 20B MXFP4 MoE | 256 | pp512 | 1229.95 | 1229.05 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 512 | pp512 | 1930.04 | 1928.51 | 1.00 |

@JohannesGaessler (Collaborator, Author) commented:

I'm noticing that for very large batch sizes there are issues with shared memory limits. On Pascal the limit should be `-ub 6144`, on Turing `-ub 8192`, and on Ampere/Ada Lovelace `-ub 12672`. It would probably make sense to try and refactor the code to avoid this.
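For context (this is inferred from the numbers, not stated explicitly in the PR): those limits line up with the maximum shared memory per CUDA block if 8 bytes of shared memory are needed per token, i.e. the two 32-bit values mentioned in the next comment: 48 KiB / 8 B = 6144 on Pascal, 64 KiB / 8 B = 8192 on Turing, and 99 KiB (101376 B) / 8 B = 12672 on Ampere/Ada Lovelace.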

@JohannesGaessler (Collaborator, Author) commented:

Previously I was storing the token position in the physical batch and the expert index as 32-bit integers. However, that many bits are not needed: if you store the token position with 22 bits and the expert index with 10 bits, you need only half as much shared memory, and the maximum physical batch size becomes twice as high as I outlined in my previous post, which I think will be enough. The assigned bits only become insufficient for a physical batch size > 4M or more than 1024 used experts, which should be well above the values the code needs to run for.
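A minimal sketch of this kind of packing (the exact bit layout and helper names are illustrative assumptions, not necessarily the encoding used in the PR):

```cpp
#include <cstdint>

// 22 bits for the token position in the physical batch (up to ~4M tokens) and
// 10 bits for the expert index (up to 1024 used experts), packed into one 32-bit value.
static __device__ __forceinline__ uint32_t pack_pos_expert(uint32_t pos, uint32_t expert) {
    return (pos << 10) | (expert & 0x3FF);
}

static __device__ __forceinline__ void unpack_pos_expert(uint32_t packed, uint32_t & pos, uint32_t & expert) {
    pos    = packed >> 10;
    expert = packed & 0x3FF;
}
```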

@JohannesGaessler merged commit 5eff6ec into ggml-org:master on Aug 25, 2025. 48 checks passed.
Minh141120 pushed commits to menloresearch/llama.cpp referencing this pull request on Aug 26 and Aug 27, 2025 ("CUDA: MoE helper in device code, better tile sizes"; "reduce superfluous CUDA blocks"). Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request on Oct 6, 2025.