Commit fabcce9

siddartha-REorca-zhang authored and committed
Add support for DeepSeek-R1 flash attention
1 parent 2e02087 commit fabcce9

File tree

1 file changed: +59 −3 lines changed


ggml/src/ggml-cuda/fattn.cu

Lines changed: 59 additions & 3 deletions
@@ -10,10 +10,66 @@
 
 template <int D, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (Q->ne[1] <= 8/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 192:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
         return;
     }
 
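For context: the new branch only runs when the caller has forced the op's precision away from the default, and the case 192 dispatch is what accommodates DeepSeek-R1's 192-wide attention heads. Below is a minimal host-side sketch of how a flash-attention op might be built so that it reaches this path. The shapes, head counts, and scale are illustrative assumptions, not taken from the commit, and it assumes the current ggml_flash_attn_ext signature (with a logit_softcap argument):

#include <math.h>
#include "ggml.h"

// Hypothetical shapes: head size 192 (matching DeepSeek-R1's K/Q head dim),
// 16 heads, 8 query tokens, 4096 cached KV entries -- illustrative only.
static struct ggml_tensor * build_fattn(struct ggml_context * ctx) {
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 192,    8, 16, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 192, 4096, 16, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 192, 4096, 16, 1);

    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL,
        /*scale =*/ 1.0f/sqrtf(192.0f), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    // Requesting F32 precision makes prec != GGML_PREC_DEFAULT in the dispatch
    // above, so the CUDA backend takes the new WMMA path instead of MMA.
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
    return kqv;
}

With these values, Q->ne[0] = 192 > 128, so the dispatch should select the cols_per_block = 16 branch and hit the newly added case 192 instantiation.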
