
Conversation

@am17an (Collaborator) commented on Sep 3, 2025

Add support for mul_mat_id for bs < 16. This kernel works by calculating activations for each token for each expert, except it filters out experts that are not used in {n_expert_used, token}; it therefore does more compute, but is overall faster than the existing path for bs <= 16.
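For context, what MUL_MAT_ID computes can be sketched as a plain reference loop; the layout and names below are illustrative only and do not correspond to ggml's actual tensor strides or to the CUDA kernel added here:

```cpp
// Schematic reference for MUL_MAT_ID (illustrative layout, not ggml's actual one):
// for every token and every used-expert slot, the expert selected in `ids`
// multiplies its weight matrix with that token's activations.
void mul_mat_id_reference(const float * W,   // [n_expert][m][k]   expert weight matrices
                          const float * x,   // [n_tokens][k]      activations
                          const int   * ids, // [n_tokens][n_used] selected expert per slot
                          float       * dst, // [n_tokens][n_used][m]
                          int n_used, int n_tokens, int m, int k) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int s = 0; s < n_used; ++s) {
            const int e = ids[t*n_used + s]; // expert chosen for this (token, slot)
            for (int row = 0; row < m; ++row) {
                float sum = 0.0f;
                for (int i = 0; i < k; ++i) {
                    sum += W[((size_t) e*m + row)*k + i] * x[(size_t) t*k + i];
                }
                dst[((size_t) t*n_used + s)*m + row] = sum;
            }
        }
    }
}
```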

On a 4090:

| Backend | GGML op | Op parameters | TFLOPS master | TFLOPS mmf2 | Speedup |
| --- | --- | --- | --- | --- | --- |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=1,k=2048,o=1 | 2.32 | 2.31 | 1.00 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1 | 0.26 | 1.25 | 4.84 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048,o=1 | 4.50 | 4.73 | 1.05 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=8,k=2048,o=1 | 0.30 | 1.14 | 3.77 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=1,k=2048,o=1 | 1.75 | 1.74 | 0.99 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1 | 0.36 | 0.55 | 1.52 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048,o=1 | 5.03 | 5.04 | 1.00 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=8,k=2048,o=1 | 0.40 | 0.55 | 1.37 |

Also, the normal mmf path is not affected:

| Backend | GGML op | Op parameters | TFLOPS master | TFLOPS mul_mat_id_mmf | Speedup |
| --- | --- | --- | --- | --- | --- |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1,o=1 | 2.42 | 2.43 | 1.01 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0,o=1 | 4.45 | 4.47 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 0.97 | 0.97 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 1.92 | 1.92 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 2.88 | 2.88 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 3.84 | 3.84 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 4.79 | 4.79 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 173.70 | 174.00 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f16,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 7.63 | 7.63 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 0.49 | 0.49 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 0.97 | 0.97 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 1.46 | 1.46 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 1.94 | 1.94 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 2.42 | 2.42 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 81.94 | 82.32 | 1.00 |
| CUDA0 | MUL_MAT | type_a=f32,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1 | 3.87 | 3.87 | 1.00 |

github-actions bot added the testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) labels on Sep 3, 2025
am17an changed the title from "CUDA: Add mul_mat_id support the mmf" to "CUDA: Add mul_mat_id support for the mmf kernel" on Sep 3, 2025
@JohannesGaessler (Collaborator) left a comment

If I understand your implementation correctly you are launching CUDA blocks per expert and number of used experts. The implementation I had in mind would be to launch blocks per expert, and to iterate over the number of used experts. Otherwise the iteration over the expert data (the expensive part) will scale with the number of used experts.

@am17an (Collaborator, Author) commented on Sep 4, 2025

> If I understand your implementation correctly you are launching CUDA blocks per expert and number of used experts. The implementation I had in mind would be to launch blocks per expert, and to iterate over the number of used experts. Otherwise the iteration over the expert data (the expensive part) will scale with the number of used experts.

It launches n_expert * n_expert_used blocks. Each block first scans ids, and if no token assigns that expert to that used slot, it exits immediately. If any token matches, the block computes the full matmul for all tokens and writes only the matching positions. So the expensive work scales with the number of (slot, expert) pairs present in the batch and with the number of tokens, not with the number of used experts. Did you have something else in mind?
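A minimal sketch of that control flow, with hypothetical names and an assumed row-major ids layout (the real kernel tiles the matmul with mma and shared memory):

```cpp
// Hypothetical sketch of the per-(expert, slot) launch described above:
// gridDim.x = n_expert, gridDim.y = n_expert_used; unused pairs exit early.
__global__ void mmf_ids_per_slot_sketch(const int * ids, int n_tokens, int n_expert_used) {
    const int expert = blockIdx.x;
    const int slot   = blockIdx.y;

    // 1. scan ids: does any token select this expert in this slot?
    int any_match = 0;
    for (int t = threadIdx.x; t < n_tokens; t += blockDim.x) {
        any_match |= ids[t*n_expert_used + slot] == expert;
    }
    if (!__syncthreads_or(any_match)) {
        return; // no token uses (expert, slot) -> skip the expensive matmul entirely
    }

    // 2. compute the matmul for the whole token batch with this expert's weights,
    //    then write back only the positions where ids[t][slot] == expert.
    //    (tiling/mma details omitted in this sketch)
}
```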

@JohannesGaessler (Collaborator) commented

At a given token position each expert is guaranteed to appear at most once in ids. Initialize arrays for the src1 and dst indices of the expert with -1. Iterate over ids and set the indices that should be loaded from/written to. If the expert does not appear in ids, just set the src1 index to 0; I think that will be faster than checking whether the data should be loaded. At the end, only write back the data where the dst indices are not -1.

> So the expensive work scales with the number of (slot, expert) pairs present in the batch and with the number of tokens, not with the number of used experts.

For multiple token positions an expert is not guaranteed to appear at the same index in ids. With this implementation, if an expert is used 2 times at index 0 and 1 time at index 1, 2 CUDA blocks will execute the matrix multiplication with that expert while wasting 13/16 or 29/32 of the compute. If you consolidate the 3 tokens into a single CUDA block only 5/8 or 13/16 would be wasted.
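A sketch of this indexing scheme, assuming one CUDA block per expert and a token tile of cols_per_block columns (names and layout are illustrative):

```cpp
// Hypothetical sketch: one block per expert. Tokens that use this expert, at
// whatever slot, are consolidated into one tile; loads are unconditional and
// only valid columns are written back.
__global__ void mmf_ids_per_expert_sketch(const int * ids, int n_tokens, int n_expert_used) {
    constexpr int cols_per_block = 16;
    const int expert = blockIdx.x;

    __shared__ int src1_idx[cols_per_block]; // which token's activations to load
    __shared__ int dst_idx [cols_per_block]; // where to write the result, -1 = skip

    for (int col = threadIdx.x; col < cols_per_block; col += blockDim.x) {
        src1_idx[col] = 0;  // default to token 0: cheaper than branching on the load
        dst_idx [col] = -1; // -1 marks "do not write back"
    }
    __syncthreads();

    // each expert appears at most once per token, so every token maps to at most one column
    int col = 0;
    for (int t = 0; t < n_tokens && col < cols_per_block; ++t) {
        for (int s = 0; s < n_expert_used; ++s) {
            if (ids[t*n_expert_used + s] == expert) {
                if (threadIdx.x == 0) {
                    src1_idx[col] = t;                   // load this token's activations
                    dst_idx [col] = t*n_expert_used + s; // write the result to this output row
                }
                ++col;
                break;
            }
        }
    }
    __syncthreads();

    // matmul over all cols_per_block columns using src1_idx for the loads,
    // then write back only the columns where dst_idx[col] != -1.
}
```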

@am17an (Collaborator, Author) commented on Sep 4, 2025

New performance numbers on a 3090 (there is some unexplained variability of +/- 5 on my card):

| Backend | GGML op | Op parameters | TFLOPS master | TFLOPS mul_mat_id_for_mmf | Speedup |
| --- | --- | --- | --- | --- | --- |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=1,k=2048,o=1 | 0.80 | 0.81 | 1.02 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1 | 0.16 | 0.84 | 5.15 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048,o=1 | 4.50 | 4.56 | 1.01 |
| CUDA0 | MUL_MAT_ID | type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=8,k=2048,o=1 | 0.20 | 1.00 | 5.06 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=1,k=2048,o=1 | 0.41 | 0.42 | 1.01 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1 | 0.19 | 0.49 | 2.54 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048,o=1 | 4.14 | 4.34 | 1.05 |
| CUDA0 | MUL_MAT_ID | type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=8,k=2048,o=1 | 0.29 | 0.51 | 1.77 |

@JohannesGaessler (Collaborator) commented

> there is some unexplained variability of +/- 5 on my card

Probably because test-backend-ops uses unseeded random data, so the number of active experts is variable.

@JohannesGaessler (Collaborator) left a comment

This could be optimized further by using more warps for the check, passing the number of used experts as a template parameter, or modifying the code for loading src1. But this does not need to be done in this PR.

Please add a function like mul_mat_f_switch_ids instead of one conditional statement for each nwarps value.
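The helper could look roughly like the following; the kernel template here is a simplified stand-in, and the real mul_mat_f_switch_ids in the PR may take different parameters:

```cpp
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

// Simplified stand-in for the real mul_mat_f kernel template in mmf.cuh.
template <typename T, int nwarps, bool has_ids>
__global__ void mul_mat_f_sketch(const T * x, const float * y, const int32_t * ids, float * dst) {
    // placeholder body: the real kernel performs the tiled matrix multiplication
}

// One helper resolves the ids/no-ids branch, so each per-nwarps call site becomes
// a single call instead of a duplicated conditional.
template <typename T, int nwarps>
static void mul_mat_f_switch_ids(const T * x, const float * y, const int32_t * ids, float * dst,
                                 const dim3 & grid, const dim3 & block, size_t smem, cudaStream_t stream) {
    if (ids != nullptr) {
        mul_mat_f_sketch<T, nwarps, true ><<<grid, block, smem, stream>>>(x, y, ids, dst);
    } else {
        mul_mat_f_sketch<T, nwarps, false><<<grid, block, smem, stream>>>(x, y, ids, dst);
    }
}
```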

@JohannesGaessler (Collaborator) commented

The code crashes for batch sizes % 4 != 0:

```
========= Invalid __shared__ read of size 16 bytes
=========     at void ggml_cuda_mma::load_ldmatrix<__half2>(ggml_cuda_mma::tile<(int)16, (int)8, T1> &, const T1 *, int)+0x11e0 in mma.cuh:278
=========     by thread (3,1,0) in block (5,0,0)
=========     Access at 0xaf8 is misaligned
=========         Device Frame: void mul_mat_f<__half2, (int)32, (int)14, (int)8, (bool)1>(const T1 *, const float *, const int *, float *, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)+0x10c0 in mmf.cu:91
=========     Saved host backtrace up to driver entry point at kernel launch time
=========         Host Frame: ggml_cuda_mul_mat_f(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) [0x21c4bae] in libggml-cuda.so
=========         Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) [0x210427c] in libggml-cuda.so
=========         Host Frame: ggml_backend_sched_graph_compute_async [0x2dfa2] in libggml-base.so
=========         Host Frame: llama_context::graph_compute(ggml_cgraph*, bool) [0x9340f] in libllama.so
=========         Host Frame: llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) [0x94c92] in libllama.so
=========         Host Frame: llama_context::decode(llama_batch const&) [0x99bee] in libllama.so
=========         Host Frame: llama_decode [0x9abfd] in libllama.so
=========         Host Frame: main [0x48e0c] in cli
```

The problem is that you're shifting the shared memory for the compute by an amount that is not divisible by 16 bytes. You can fix it by padding the shared memory for the ids. Or, since the shared memory is being padded anyway, you could write the ids into the padding (this way you would also not need any extra shared memory, though the difference should be negligible).
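One way to keep the compute tiles aligned is to round the shared-memory space reserved for the ids up to a multiple of 16 bytes; a minimal sketch with illustrative names:

```cpp
// Sketch: the ids live at the front of dynamic shared memory, and the space they
// occupy is rounded up to a multiple of 16 bytes so the tile data behind them
// stays aligned for the 16-byte ldmatrix loads.
extern __shared__ char data_mmf[];

__device__ inline char * shared_tile_base_sketch(int cols_per_block, int ** ids_dst) {
    const size_t ids_bytes        = cols_per_block*sizeof(int);
    const size_t ids_bytes_padded = (ids_bytes + 15) & ~size_t(15); // round up to 16 B

    *ids_dst = reinterpret_cast<int *>(data_mmf); // ids at the front of shared memory
    return data_mmf + ids_bytes_padded;           // compute tiles start 16-byte aligned
}
```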

@JohannesGaessler (Collaborator) commented

Performance
| GPU | Model | Microbatch size | Test | t/s b6367 | t/s 54f6a56 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| 2x RTX 4090 | deepseek2 16B F16 | 1 | pp512 | 269.94 | 269.76 | 1.00 |
| 2x RTX 4090 | deepseek2 16B F16 | 2 | pp512 | 130.55 | 356.30 | 2.73 |
| 2x RTX 4090 | deepseek2 16B F16 | 3 | pp512 | 164.22 | 452.47 | 2.76 |
| 2x RTX 4090 | deepseek2 16B F16 | 4 | pp512 | 193.13 | 548.85 | 2.84 |
| 2x RTX 4090 | deepseek2 16B F16 | 5 | pp512 | 222.08 | 629.30 | 2.83 |
| 2x RTX 4090 | deepseek2 16B F16 | 6 | pp512 | 249.16 | 715.48 | 2.87 |
| 2x RTX 4090 | deepseek2 16B F16 | 7 | pp512 | 272.17 | 782.85 | 2.88 |
| 2x RTX 4090 | deepseek2 16B F16 | 8 | pp512 | 294.58 | 876.95 | 2.98 |
| 2x RTX 4090 | deepseek2 16B F16 | 9 | pp512 | 309.33 | 896.82 | 2.90 |
| 2x RTX 4090 | deepseek2 16B F16 | 10 | pp512 | 341.48 | 986.71 | 2.89 |
| 2x RTX 4090 | deepseek2 16B F16 | 11 | pp512 | 364.46 | 1057.13 | 2.90 |
| 2x RTX 4090 | deepseek2 16B F16 | 12 | pp512 | 391.97 | 1125.30 | 2.87 |
| 2x RTX 4090 | deepseek2 16B F16 | 13 | pp512 | 395.94 | 1147.07 | 2.90 |
| 2x RTX 4090 | deepseek2 16B F16 | 14 | pp512 | 427.10 | 1229.87 | 2.88 |
| 2x RTX 4090 | deepseek2 16B F16 | 15 | pp512 | 445.26 | 1286.09 | 2.89 |
| 2x RTX 4090 | deepseek2 16B F16 | 16 | pp512 | 456.93 | 1369.46 | 3.00 |

mmf.cu has now become the CUDA file with the longest compilation duration; it should be split across multiple files, as has been done with e.g. MMQ or the FA kernels. Preferably this should still be done in this PR to avoid stalling the CI.

I alluded to further performance optimizations further up; do you intend to work on those? If not, I will make a follow-up PR.

@am17an (Collaborator, Author) commented on Sep 5, 2025

I'll do the optimisations too, after this PR. One thing I also haven't done is check how fast this is compared to the baseline for a higher token count; I think for f16 it will be quite a lot higher than 16. However, that would really blow up the compile time.

github-actions bot added the python (python script changes) label on Sep 7, 2025
am17an force-pushed the mul_mat_id_for_mmf branch from 73c3bbf to bee77df on September 7, 2025 at 02:51
@am17an (Collaborator, Author) commented on Sep 7, 2025

@JohannesGaessler please take a look when you get time

@JohannesGaessler (Collaborator) commented

On my system this does not reduce the compilation time: the bulk of the compilation is still done for mmf.cu, with no noticeable reduction in compile time or in the size of the .cu.o file.

am17an force-pushed the mul_mat_id_for_mmf branch from 311d34d to e6528a3 on September 8, 2025 at 14:35
am17an force-pushed the mul_mat_id_for_mmf branch from e6528a3 to 43d2bc4 on September 8, 2025 at 14:53
@am17an (Collaborator, Author) commented on Sep 8, 2025

I divided the templates according to ncols_dst; the following are the compile times for the .cu.o files (recorded on an AMD Ryzen 3080 @ 4 GHz):

Rank Time (ms) File
1 75525 template-instances/mmq-instance-q2_k.cu.o
2 56127 template-instances/mmq-instance-q3_k.cu.o
3 55389 template-instances/mmq-instance-iq2_s.cu.o
4 54852 template-instances/mmq-instance-q6_k.cu.o
5 51608 template-instances/mmq-instance-iq2_xs.cu.o
6 51504 template-instances/mmq-instance-q5_1.cu.o
7 50885 template-instances/mmq-instance-q5_0.cu.o
8 49772 template-instances/mmq-instance-q5_k.cu.o
9 48322 template-instances/mmq-instance-q4_k.cu.o
10 48236 template-instances/mmq-instance-q4_1.cu.o
11 47534 template-instances/mmq-instance-iq3_s.cu.o
12 47435 template-instances/mmq-instance-iq1_s.cu.o
13 47040 template-instances/mmq-instance-mxfp4.cu.o
14 46204 template-instances/mmq-instance-q4_0.cu.o
15 44698 template-instances/mmq-instance-iq3_xxs.cu.o
16 44621 template-instances/mmq-instance-iq2_xxs.cu.o
17 44421 template-instances/mmq-instance-iq4_nl.cu.o
18 44396 template-instances/mmq-instance-iq4_xs.cu.o
19 42360 template-instances/mmq-instance-q8_0.cu.o
20 37690 template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu.o
21 36769 mmvf.cu.o
22 36735 template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu.o
23 36461 template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu.o
24 36135 mmvq.cu.o
25 35278 template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu.o
26 26949 template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu.o
27 26903 template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu.o
28 26831 template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu.o
29 25802 template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu.o
30 24504 template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu.o
31 24328 template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu.o
32 24285 template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu.o
33 23444 template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu.o
34 21089 template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu.o
35 20548 template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu.o
36 20465 template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu.o
37 20010 template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu.o
38 17262 ssm-scan.cu.o
39 16099 mean.cu.o
40 15953 sum.cu.o
41 14640 template-instances/mmf-instance-ncols_16.cu.o
42 14433 template-instances/mmf-instance-ncols_15.cu.o
43 14076 template-instances/mmf-instance-ncols_13.cu.o
44 14051 template-instances/mmf-instance-ncols_14.cu.o
45 13688 template-instances/mmf-instance-ncols_10.cu.o
46 13671 template-instances/mmf-instance-ncols_11.cu.o
47 13565 template-instances/mmf-instance-ncols_12.cu.o
48 13022 template-instances/mmf-instance-ncols_9.cu.o
49 12612 template-instances/mmf-instance-ncols_7.cu.o
50 12515 template-instances/mmf-instance-ncols_8.cu.o
51 12385 binbcast.cu.o
52 12195 template-instances/mmf-instance-ncols_6.cu.o
53 12108 template-instances/mmf-instance-ncols_5.cu.o
54 12039 template-instances/mmf-instance-ncols_4.cu.o
55 11827 template-instances/mmf-instance-ncols_3.cu.o
56 11568 template-instances/mmf-instance-ncols_2.cu.o
57 11475 template-instances/mmf-instance-ncols_1.cu.o
58 11457 template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu.o
59 11212 template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu.o
60 10807 template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu.o
61 10280 fattn-wmma-f16.cu.o
62 9414 template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu.o
63 9345 ggml-cuda.cu.o
64 9141 template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu.o
65 8986 template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu.o
66 8893 template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu.o
67 8396 template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu.o
68 8169 template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu.o
69 7423 template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu.o
70 7037 convert.cu.o
71 6697 fattn-tile.cu.o
72 6404 unary.cu.o
73 6363 getrows.cu.o
74 6336 template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu.o
75 6290 cpy.cu.o
76 5962 softmax.cu.o
77 5877 rope.cu.o
78 5834 template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu.o
79 5802 norm.cu.o
80 5748 wkv.cu.o
81 5530 mmq.cu.o
82 5414 set-rows.cu.o
83 5266 template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu.o
84 4978 concat.cu.o
85 4955 im2col.cu.o
86 4939 acc.cu.o
87 4895 argmax.cu.o
88 4893 argsort.cu.o
89 4892 quantize.cu.o
90 4846 fattn.cu.o
91 4842 gla.cu.o
92 4829 arange.cu.o
93 4827 cross-entropy-loss.cu.o
94 4820 add-id.cu.o
95 4811 clamp.cu.o
96 4774 conv2d-dw.cu.o
97 4749 ssm-conv.cu.o
98 4700 conv2d.cu.o
99 4687 mmf.cu.o
100 4656 conv2d-transpose.cu.o
101 4648 upscale.cu.o
102 4627 conv-transpose-1d.cu.o
103 4602 opt-step-adamw.cu.o
104 4582 pad_reflect_1d.cu.o
105 4565 count-equal.cu.o
106 4520 pad.cu.o
107 4509 tsembd.cu.o
108 4502 sumrows.cu.o
109 4498 diagmask.cu.o
110 4477 pool2d.cu.o
111 4461 out-prod.cu.o
112 4451 opt-step-sgd.cu.o
113 4451 scale.cu.o
114 4411 softcap.cu.o
115 4394 roll.cu.o

am17an force-pushed the mul_mat_id_for_mmf branch from 43d2bc4 to bc12fd1 on September 8, 2025 at 15:26
@JohannesGaessler (Collaborator) left a comment

I think you forgot to add the new files generated by generate_cu_files.py to git.

@JohannesGaessler (Collaborator) left a comment

Thank you, this is a pretty good speedup.

Performance changes
| GPU | Model | Microbatch size | Test | t/s b0d5299 | t/s 184e42d | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | granitemoe 3B F16 | 1 | pp512 | 221.91 | 221.29 | 1.00 |
| RTX 3090 | granitemoe 3B F16 | 2 | pp512 | 121.58 | 366.75 | 3.02 |
| RTX 3090 | granitemoe 3B F16 | 3 | pp512 | 162.26 | 504.86 | 3.11 |
| RTX 3090 | granitemoe 3B F16 | 4 | pp512 | 199.52 | 626.74 | 3.14 |
| RTX 3090 | granitemoe 3B F16 | 5 | pp512 | 231.65 | 740.41 | 3.20 |
| RTX 3090 | granitemoe 3B F16 | 6 | pp512 | 268.74 | 862.74 | 3.21 |
| RTX 3090 | granitemoe 3B F16 | 7 | pp512 | 300.47 | 966.34 | 3.22 |
| RTX 3090 | granitemoe 3B F16 | 8 | pp512 | 335.04 | 1089.98 | 3.25 |
| RTX 3090 | granitemoe 3B F16 | 9 | pp512 | 355.31 | 1130.60 | 3.18 |
| RTX 3090 | granitemoe 3B F16 | 10 | pp512 | 383.52 | 1226.65 | 3.20 |
| RTX 3090 | granitemoe 3B F16 | 11 | pp512 | 416.42 | 1323.06 | 3.18 |
| RTX 3090 | granitemoe 3B F16 | 12 | pp512 | 447.67 | 1419.12 | 3.17 |
| RTX 3090 | granitemoe 3B F16 | 13 | pp512 | 470.12 | 1491.19 | 3.17 |
| RTX 3090 | granitemoe 3B F16 | 14 | pp512 | 499.06 | 1579.86 | 3.17 |
| RTX 3090 | granitemoe 3B F16 | 15 | pp512 | 527.34 | 1649.31 | 3.13 |
| RTX 3090 | granitemoe 3B F16 | 16 | pp512 | 555.90 | 1744.24 | 3.14 |
| RTX 4090 | granitemoe 3B F16 | 1 | pp512 | 313.27 | 313.36 | 1.00 |
| RTX 4090 | granitemoe 3B F16 | 2 | pp512 | 123.59 | 444.67 | 3.60 |
| RTX 4090 | granitemoe 3B F16 | 3 | pp512 | 149.77 | 609.81 | 4.07 |
| RTX 4090 | granitemoe 3B F16 | 4 | pp512 | 180.45 | 756.26 | 4.19 |
| RTX 4090 | granitemoe 3B F16 | 5 | pp512 | 206.47 | 890.42 | 4.31 |
| RTX 4090 | granitemoe 3B F16 | 6 | pp512 | 235.36 | 1035.59 | 4.40 |
| RTX 4090 | granitemoe 3B F16 | 7 | pp512 | 261.88 | 1156.39 | 4.42 |
| RTX 4090 | granitemoe 3B F16 | 8 | pp512 | 289.49 | 1300.09 | 4.49 |
| RTX 4090 | granitemoe 3B F16 | 9 | pp512 | 306.20 | 1363.93 | 4.45 |
| RTX 4090 | granitemoe 3B F16 | 10 | pp512 | 333.31 | 1486.41 | 4.46 |
| RTX 4090 | granitemoe 3B F16 | 11 | pp512 | 352.67 | 1605.19 | 4.55 |
| RTX 4090 | granitemoe 3B F16 | 12 | pp512 | 375.57 | 1720.24 | 4.58 |
| RTX 4090 | granitemoe 3B F16 | 13 | pp512 | 394.11 | 1814.92 | 4.61 |
| RTX 4090 | granitemoe 3B F16 | 14 | pp512 | 417.75 | 1927.71 | 4.61 |
| RTX 4090 | granitemoe 3B F16 | 15 | pp512 | 435.70 | 2029.94 | 4.66 |
| RTX 4090 | granitemoe 3B F16 | 16 | pp512 | 457.82 | 2154.04 | 4.71 |
| 2x RTX 4090 | deepseek2 16B F16 | 1 | pp512 | 273.92 | 273.86 | 1.00 |
| 2x RTX 4090 | deepseek2 16B F16 | 2 | pp512 | 131.71 | 361.76 | 2.75 |
| 2x RTX 4090 | deepseek2 16B F16 | 3 | pp512 | 164.75 | 458.22 | 2.78 |
| 2x RTX 4090 | deepseek2 16B F16 | 4 | pp512 | 193.37 | 554.75 | 2.87 |
| 2x RTX 4090 | deepseek2 16B F16 | 5 | pp512 | 222.33 | 635.67 | 2.86 |
| 2x RTX 4090 | deepseek2 16B F16 | 6 | pp512 | 250.58 | 722.15 | 2.88 |
| 2x RTX 4090 | deepseek2 16B F16 | 7 | pp512 | 273.81 | 789.98 | 2.89 |
| 2x RTX 4090 | deepseek2 16B F16 | 8 | pp512 | 296.32 | 884.05 | 2.98 |
| 2x RTX 4090 | deepseek2 16B F16 | 9 | pp512 | 310.39 | 903.10 | 2.91 |
| 2x RTX 4090 | deepseek2 16B F16 | 10 | pp512 | 341.71 | 992.85 | 2.91 |
| 2x RTX 4090 | deepseek2 16B F16 | 11 | pp512 | 364.37 | 1064.10 | 2.92 |
| 2x RTX 4090 | deepseek2 16B F16 | 12 | pp512 | 392.72 | 1132.44 | 2.88 |
| 2x RTX 4090 | deepseek2 16B F16 | 13 | pp512 | 396.68 | 1154.09 | 2.91 |
| 2x RTX 4090 | deepseek2 16B F16 | 14 | pp512 | 427.56 | 1237.19 | 2.89 |
| 2x RTX 4090 | deepseek2 16B F16 | 15 | pp512 | 445.56 | 1293.89 | 2.90 |
| 2x RTX 4090 | deepseek2 16B F16 | 16 | pp512 | 458.10 | 1377.46 | 3.01 |

am17an force-pushed the mul_mat_id_for_mmf branch from fda19db to 9db0ac4 on September 9, 2025 at 02:35
am17an merged commit a972fae into ggml-org:master on Sep 9, 2025 (50 checks passed)
am17an deleted the mul_mat_id_for_mmf branch on September 9, 2025 at 06:38
njsyw1997 pushed a commit to aizip/llama.cpp that referenced this pull request on Sep 10, 2025:
* CUDA: Add mul_mat_id support the mmf

Add support for mul_mat_id for bs < 16

* Review: use warp_size, fix should_use_mmf condition

* Launch one block per expert, stride along n_expert_used

* templatize mul_mat_id

* Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids

* Reduce compile times by dividing mmf into f16, bf16 and f32 variants

* Divide mmf by ncols_dst

* Add missing files

* Fix MUSA/HIP builds