Commit 30536ee
Authored by Iwan Kawrakow (ikawrakow)
FlashMLA-3 for DeepSeek models on CUDA (#386)
* CUDA WIP: support for FlashMLA-3
* Much better. The issue was that I had not changed the number of warps used for 3D matrix multiplications (wk_b * kv_cache, MoE), so we ended up using 4 warps for TG. By going to 1 warp in these cases, we get a significant boost in TG performance (tested with DeepSeek-Lite).
* Sadly, the previous commit was wrong
* Finalizing
* Also add these
* Minor
* Minor tweak

Co-authored-by: Iwan Kawrakow <[email protected]>
Parent: 17c6fc6 · Commit: 30536ee

File tree

5 files changed: +1798 additions, −45 deletions


ggml/src/ggml-cuda.cu (5 additions, 0 deletions)

@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
                    (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
         }
+        if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
+            return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
+        }
         if (op->src[1]->ne[0] > 256) {
             return false;
         }
