CUDA: Add mul_mat_id support for the mmf kernel #15767
Conversation
If I understand your implementation correctly, you are launching CUDA blocks per expert and per used-expert slot. The implementation I had in mind would be to launch blocks per expert and to iterate over the number of used experts. Otherwise the iteration over the expert data (the expensive part) will scale with the number of used experts.
It launches n_expert * n_expert_used blocks. Each block first scans ids and, if no token assigns that expert to that used slot, it exits immediately. If any token matches, the block computes the full matmul for all tokens and writes only the matching positions. So the expensive work scales with the number of (slot, expert) pairs present in the batch and the number of tokens, not with the number of used experts. Did you have something else in mind?
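For illustration, a minimal sketch of the block layout and early-exit scan described above, assuming a row-major ids tensor of shape [n_tokens, n_expert_used]; the names, grid mapping, and the omitted matmul body are simplifications, not the actual mmf kernel:

```cuda
// Sketch only: one block per (expert, used-expert slot) pair, early exit if the
// pair does not occur in the batch. The real kernel also tiles over output rows.
__global__ void mmf_ids_scan_sketch(const int32_t * ids, int n_tokens, int n_expert_used) {
    const int expert = blockIdx.y; // which expert this block handles
    const int slot   = blockIdx.z; // which used-expert slot this block handles

    // Scan ids: does any token select `expert` in this slot?
    int match = 0;
    for (int t = threadIdx.x; t < n_tokens; t += blockDim.x) {
        match |= ids[t*n_expert_used + slot] == expert;
    }
    // Block-wide OR; if no token matches, skip the expensive matmul entirely.
    if (!__syncthreads_or(match)) {
        return;
    }

    // ... compute the matmul for this expert over all tokens and write results
    //     only for tokens t where ids[t*n_expert_used + slot] == expert ...
}
```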
At a given token position each expert is guaranteed to appear either 0 or 1 times in `ids`.
For multiple token positions an expert is not guaranteed to appear at the same index in `ids`.
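To make those two guarantees concrete, a hypothetical `ids` layout for n_expert_used = 2 (values chosen purely for illustration):

```cuda
// ids is [n_tokens][n_expert_used], row-major. For n_expert_used = 2:
//   token 0: ids = {3, 7}  -> expert 3 appears exactly once, in slot 0
//   token 1: ids = {7, 3}  -> expert 3 again appears exactly once, but in slot 1
//   token 2: ids = {5, 7}  -> expert 3 does not appear at all
// Per token an expert occurs 0 or 1 times, but its slot index can differ between
// tokens, which is why each (expert, slot) pair has to be checked.
```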
New performance numbers on a 3090 (there is some unexplained variability of +/- 5 on my card):
Probably because
This could be optimized further by using more warps for the check, passing the number of used experts as a template parameter, or modifying the code for loading src1. But this does not need to be done in this PR.
Please add a function like mul_mat_f_switch_ids instead of one conditional statement for each nwarps value.
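A rough sketch of what such a helper could look like, assuming a `mul_mat_f` kernel template with a boolean `has_ids` parameter; the signature is invented for illustration and is not the actual llama.cpp code:

```cpp
// Hypothetical helper: decide the ids/no-ids kernel variant in one place, so the
// per-nwarps call sites don't each need their own conditional statement.
template <typename T, int cols_per_block, int nwarps>
static void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const dim3 & grid, const dim3 & block, size_t nbytes_shared, cudaStream_t stream) {
    if (ids != nullptr) {
        mul_mat_f<T, cols_per_block, nwarps, /*has_ids=*/true>
            <<<grid, block, nbytes_shared, stream>>>(x, y, ids, dst);
    } else {
        mul_mat_f<T, cols_per_block, nwarps, /*has_ids=*/false>
            <<<grid, block, nbytes_shared, stream>>>(x, y, ids, dst);
    }
}
```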
The code crashes for batch sizes % 4 != 0:
The problem is that you're shifting the shared memory for the compute by an amount not divisible by 16 bytes. You can fix it by padding the shared memory for the ids. Or, since the shared memory is being padded anyway, you could write the ids into the padding (this way you would also not need any extra shared memory, though the difference should be negligible).
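A minimal sketch of the padding fix, assuming the ids are staged at the start of the dynamic shared memory and the compute tile follows them; the buffer name, sizes, and types are illustrative, not the kernel's actual layout:

```cuda
#include <cuda_fp16.h>

extern __shared__ char smem_raw[];

__device__ void partition_shared_memory(int n_ids, int32_t ** ids_smem, half2 ** tile_smem) {
    *ids_smem = (int32_t *) smem_raw;

    // Round the ids region up to a multiple of 16 bytes so the compute tile that
    // follows stays 16-byte aligned for any batch size (including bs % 4 != 0),
    // keeping 16-byte vectorized shared memory accesses valid.
    size_t ids_bytes = (size_t) n_ids * sizeof(int32_t);
    ids_bytes = (ids_bytes + 15) & ~(size_t) 15;

    *tile_smem = (half2 *) (smem_raw + ids_bytes);
}
```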
Performance
I alluded to further performance optimizations further up; do you intend to work on those? If not, I will make a follow-up PR.
I'll do the optimisations too, after this PR. One thing I haven't done yet is check how fast this is compared to the baseline for a higher token count; I think for f16 the crossover will be quite a lot higher than 16. However, that would really blow up the compile time.
@JohannesGaessler please take a look when you get time
On my system this does not work to reduce the compilation time. The bulk of the compilation is still done for
Add support for mul_mat_id for bs < 16
I divided the templates according to ncols_dst. Following are the compile times for
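The pattern being described, sketched with hypothetical file and macro names (the real files are produced by generate_cu_files.py): each generated .cu file instantiates the mmf template for a single ncols_dst value, so the variants are spread across translation units that compile in parallel instead of one large one.

```cpp
// mmf.cuh (shared header): declare the templated launcher plus a macro for its
// explicit instantiation.
// template <typename T, int ncols_dst> void mul_mat_f_case(/* ... */);
// #define DECL_MMF_CASE(ncols_dst)                                 \
//     template void mul_mat_f_case<half,  ncols_dst>(/* ... */);   \
//     template void mul_mat_f_case<float, ncols_dst>(/* ... */);

// template-instances/mmf-instance-ncols_4.cu
// #include "../mmf.cuh"
// DECL_MMF_CASE(4);

// template-instances/mmf-instance-ncols_8.cu
// #include "../mmf.cuh"
// DECL_MMF_CASE(8);
```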
I think you forgot to add the new files generated by generate_cu_files.py to git.
Thank you, this is a pretty good speedup.
Performance changes
GPU | Model | Microbatch size | Test | t/s b0d5299 | t/s 184e42d | Speedup |
---|---|---|---|---|---|---|
RTX 3090 | granitemoe 3B F16 | 1 | pp512 | 221.91 | 221.29 | 1.00 |
RTX 3090 | granitemoe 3B F16 | 2 | pp512 | 121.58 | 366.75 | 3.02 |
RTX 3090 | granitemoe 3B F16 | 3 | pp512 | 162.26 | 504.86 | 3.11 |
RTX 3090 | granitemoe 3B F16 | 4 | pp512 | 199.52 | 626.74 | 3.14 |
RTX 3090 | granitemoe 3B F16 | 5 | pp512 | 231.65 | 740.41 | 3.20 |
RTX 3090 | granitemoe 3B F16 | 6 | pp512 | 268.74 | 862.74 | 3.21 |
RTX 3090 | granitemoe 3B F16 | 7 | pp512 | 300.47 | 966.34 | 3.22 |
RTX 3090 | granitemoe 3B F16 | 8 | pp512 | 335.04 | 1089.98 | 3.25 |
RTX 3090 | granitemoe 3B F16 | 9 | pp512 | 355.31 | 1130.60 | 3.18 |
RTX 3090 | granitemoe 3B F16 | 10 | pp512 | 383.52 | 1226.65 | 3.20 |
RTX 3090 | granitemoe 3B F16 | 11 | pp512 | 416.42 | 1323.06 | 3.18 |
RTX 3090 | granitemoe 3B F16 | 12 | pp512 | 447.67 | 1419.12 | 3.17 |
RTX 3090 | granitemoe 3B F16 | 13 | pp512 | 470.12 | 1491.19 | 3.17 |
RTX 3090 | granitemoe 3B F16 | 14 | pp512 | 499.06 | 1579.86 | 3.17 |
RTX 3090 | granitemoe 3B F16 | 15 | pp512 | 527.34 | 1649.31 | 3.13 |
RTX 3090 | granitemoe 3B F16 | 16 | pp512 | 555.90 | 1744.24 | 3.14 |
RTX 4090 | granitemoe 3B F16 | 1 | pp512 | 313.27 | 313.36 | 1.00 |
RTX 4090 | granitemoe 3B F16 | 2 | pp512 | 123.59 | 444.67 | 3.60 |
RTX 4090 | granitemoe 3B F16 | 3 | pp512 | 149.77 | 609.81 | 4.07 |
RTX 4090 | granitemoe 3B F16 | 4 | pp512 | 180.45 | 756.26 | 4.19 |
RTX 4090 | granitemoe 3B F16 | 5 | pp512 | 206.47 | 890.42 | 4.31 |
RTX 4090 | granitemoe 3B F16 | 6 | pp512 | 235.36 | 1035.59 | 4.40 |
RTX 4090 | granitemoe 3B F16 | 7 | pp512 | 261.88 | 1156.39 | 4.42 |
RTX 4090 | granitemoe 3B F16 | 8 | pp512 | 289.49 | 1300.09 | 4.49 |
RTX 4090 | granitemoe 3B F16 | 9 | pp512 | 306.20 | 1363.93 | 4.45 |
RTX 4090 | granitemoe 3B F16 | 10 | pp512 | 333.31 | 1486.41 | 4.46 |
RTX 4090 | granitemoe 3B F16 | 11 | pp512 | 352.67 | 1605.19 | 4.55 |
RTX 4090 | granitemoe 3B F16 | 12 | pp512 | 375.57 | 1720.24 | 4.58 |
RTX 4090 | granitemoe 3B F16 | 13 | pp512 | 394.11 | 1814.92 | 4.61 |
RTX 4090 | granitemoe 3B F16 | 14 | pp512 | 417.75 | 1927.71 | 4.61 |
RTX 4090 | granitemoe 3B F16 | 15 | pp512 | 435.70 | 2029.94 | 4.66 |
RTX 4090 | granitemoe 3B F16 | 16 | pp512 | 457.82 | 2154.04 | 4.71 |
2x RTX 4090 | deepseek2 16B F16 | 1 | pp512 | 273.92 | 273.86 | 1.00 |
2x RTX 4090 | deepseek2 16B F16 | 2 | pp512 | 131.71 | 361.76 | 2.75 |
2x RTX 4090 | deepseek2 16B F16 | 3 | pp512 | 164.75 | 458.22 | 2.78 |
2x RTX 4090 | deepseek2 16B F16 | 4 | pp512 | 193.37 | 554.75 | 2.87 |
2x RTX 4090 | deepseek2 16B F16 | 5 | pp512 | 222.33 | 635.67 | 2.86 |
2x RTX 4090 | deepseek2 16B F16 | 6 | pp512 | 250.58 | 722.15 | 2.88 |
2x RTX 4090 | deepseek2 16B F16 | 7 | pp512 | 273.81 | 789.98 | 2.89 |
2x RTX 4090 | deepseek2 16B F16 | 8 | pp512 | 296.32 | 884.05 | 2.98 |
2x RTX 4090 | deepseek2 16B F16 | 9 | pp512 | 310.39 | 903.10 | 2.91 |
2x RTX 4090 | deepseek2 16B F16 | 10 | pp512 | 341.71 | 992.85 | 2.91 |
2x RTX 4090 | deepseek2 16B F16 | 11 | pp512 | 364.37 | 1064.10 | 2.92 |
2x RTX 4090 | deepseek2 16B F16 | 12 | pp512 | 392.72 | 1132.44 | 2.88 |
2x RTX 4090 | deepseek2 16B F16 | 13 | pp512 | 396.68 | 1154.09 | 2.91 |
2x RTX 4090 | deepseek2 16B F16 | 14 | pp512 | 427.56 | 1237.19 | 2.89 |
2x RTX 4090 | deepseek2 16B F16 | 15 | pp512 | 445.56 | 1293.89 | 2.90 |
2x RTX 4090 | deepseek2 16B F16 | 16 | pp512 | 458.10 | 1377.46 | 3.01 |
* CUDA: Add mul_mat_id support to the mmf kernel (support for mul_mat_id for bs < 16)
* Review: use warp_size, fix should_use_mmf condition
* Launch one block per expert, stride along n_expert_used
* Templatize mul_mat_id
* Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids
* Reduce compile times by dividing mmf into f16, bf16 and f32 variants
* Divide mmf by ncols_dst
* Add missing files
* Fix MUSA/HIP builds
Add support for mul_mat_id for bs < 16. This kernel works by calculating activations for each token for each expert, except it filters out experts which are not used in {n_expert_used, token}; so it does more compute, but overall it is better than the existing path for bs <= 16 on a 4090.
Also, the normal mmf path is not affected.
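For context, a hedged sketch of the kind of dispatch condition this implies; the function name and exact gating are assumptions, not the actual should_use_mmf code:

```cpp
// Hypothetical gate: route MUL_MAT_ID through the mmf path only when the source
// type is supported and the batch is small enough that the extra filtered
// compute still beats the existing path.
static bool should_use_mmf_for_mul_mat_id(ggml_type type, int64_t n_tokens) {
    const bool type_ok = type == GGML_TYPE_F32 || type == GGML_TYPE_F16 || type == GGML_TYPE_BF16;
    return type_ok && n_tokens <= 16; // bs <= 16, per the benchmarks above
}
```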