
Commit 622cd01

ggml: CUDA: add head size 72 for flash-attn (#16962)
1 parent 070ff4d commit 622cd01


5 files changed: +44 -5 lines changed


ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 4 additions & 0 deletions

@@ -14,6 +14,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
         } break;
+        case 72: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
+        } break;
         case 80: {
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
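For context, the new case above slots into the switch over the K head size in ggml_cuda_flash_attn_ext_tile, which forwards to an explicitly instantiated template case. Below is a minimal standalone sketch of that dispatch shape; the names are stand-ins, not the actual ggml-cuda symbols.

// Standalone sketch of a head-size dispatch like the hunk above (illustrative
// stand-in names only, not the ggml-cuda API).
#include <cstdint>
#include <cstdio>

template <int DKQ, int DV>
static void tile_case_sketch() {
    std::printf("tile kernel specialized for DKQ=%d, DV=%d\n", DKQ, DV);
}

static void tile_dispatch_sketch(const int64_t head_size) {
    switch (head_size) {
        case 64: tile_case_sketch< 64, 64>(); break;
        case 72: tile_case_sketch< 72, 72>(); break; // case added by this commit
        case 80: tile_case_sketch< 80, 80>(); break;
        default: std::printf("head size %lld not covered in this sketch\n", (long long) head_size); break;
    }
}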

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 29 additions & 2 deletions

@@ -6,7 +6,7 @@
 // nbatch_K == number of K columns to load in parallel for KQ calculation

 // TODO optimize kernel parameters for FP16 NVIDIA (P100)
-// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
+// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112

 // The ROCm compiler cannot handle templating in __launch_bounds__.
 // As a workaround, define a macro to package the kernel parameters as uint32_t:
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)

+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)

+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)

+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)

+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(

     if (
 #ifdef GGML_USE_WMMA_FATTN
-        (ncols2 != 1 && DV != 40 && DV != 512) ||
+        (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
         (use_logit_softcap && !(DV == 128 || DV == 256))
     ) {
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor

 extern DECL_FATTN_TILE_CASE( 40, 40);
 extern DECL_FATTN_TILE_CASE( 64, 64);
+extern DECL_FATTN_TILE_CASE( 72, 72);
 extern DECL_FATTN_TILE_CASE( 80, 80);
 extern DECL_FATTN_TILE_CASE( 96, 96);
 extern DECL_FATTN_TILE_CASE(112, 112);
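The header comment above explains why the tile kernel's launch parameters are packed into a uint32_t: the ROCm compiler cannot handle templating inside __launch_bounds__. Below is a rough standalone sketch of that packing idea; the field layout, names, and my reading of the CONFIG_CASE columns are assumptions for illustration, not the actual macro.

// Illustrative only: pack a few small launch parameters into one uint32_t so
// __launch_bounds__ can consume a single compile-time constant. Field order and
// widths are assumptions, not the real GGML_CUDA_FATTN_TILE_CONFIG_CASE layout.
#include <cstdint>

constexpr uint32_t pack_tile_config(uint32_t nthreads, uint32_t occupancy, uint32_t nbatch_K) {
    return (nthreads & 0x3FF) | ((occupancy & 0xF) << 10) | ((nbatch_K & 0xFF) << 14);
}

constexpr uint32_t cfg_nthreads (uint32_t cfg) { return  cfg        & 0x3FF; }
constexpr uint32_t cfg_occupancy(uint32_t cfg) { return (cfg >> 10) & 0xF;   }
constexpr uint32_t cfg_nbatch_K (uint32_t cfg) { return (cfg >> 14) & 0xFF;  }

// Example values loosely modelled on the 72x72 entries above (which column
// carries which parameter is itself a guess here):
constexpr uint32_t cfg_72 = pack_tile_config(256, 2, 72);
static_assert(cfg_nthreads (cfg_72) == 256, "round-trip");
static_assert(cfg_occupancy(cfg_72) ==   2, "round-trip");
static_assert(cfg_nbatch_K (cfg_72) ==  72, "round-trip");

// A kernel could then be declared along the lines of:
//   __global__ void __launch_bounds__(cfg_nthreads(CFG), cfg_occupancy(CFG)) flash_attn_tile_sketch(...);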

ggml/src/ggml-cuda/fattn.cu

Lines changed: 3 additions & 2 deletions

@@ -223,6 +223,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         switch (K->ne[0]) {
             case 40:
             case 64:
+            case 72:
             case 80:
             case 96:
             case 128:
@@ -275,7 +276,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

     // If Turing tensor cores available, use them:
-    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
+    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
         if (can_use_vector_kernel) {
             if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -301,7 +302,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }

     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
+    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
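Head size 72 is added to the supported sizes in ggml_cuda_get_best_fattn_kernel but, like 40, is excluded from the Turing MMA and WMMA branches, so those models fall back to the tile kernel. Below is a simplified standalone sketch of that fall-through; the enum and helper are illustrative, not the real selection code.

// Simplified illustration of the selection behaviour the hunks above change
// (not the real ggml_cuda_get_best_fattn_kernel).
#include <cstdint>

enum class fattn_kernel_sketch { MMA, WMMA, TILE };

static fattn_kernel_sketch pick_kernel_sketch(int64_t head_size, bool turing_mma, bool wmma_ok) {
    // head sizes without tensor-core instantiations skip the MMA/WMMA branches;
    // with this commit, 72 joins 40 in that set
    const bool tensor_core_head = head_size != 40 && head_size != 72;

    if (turing_mma && tensor_core_head) {
        return fattn_kernel_sketch::MMA;
    }
    if (wmma_ok && tensor_core_head && head_size != 576) {
        return fattn_kernel_sketch::WMMA;
    }
    return fattn_kernel_sketch::TILE; // head size 72 always ends up here in this sketch
}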
ggml/src/ggml-cuda/template-instances/ (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(72, 72);
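This new autogenerated translation unit provides the 72x72 instantiation that fattn-tile.cuh declares with extern DECL_FATTN_TILE_CASE( 72, 72);, presumably so each head size compiles in its own file. Below is a standalone sketch of the extern-declaration/explicit-instantiation split; the stub types and macro body are assumptions, not the real macro.

// Illustrative sketch of the pattern only; stub types and the macro body are
// assumptions, not the actual ggml definitions.
struct ctx_stub    {};
struct tensor_stub {};

template <int DKQ, int DV>
void fattn_tile_case_stub(ctx_stub & /*ctx*/, tensor_stub * /*dst*/) {
    // the real version launches the tile kernel specialized for DKQ x DV
}

#define DECL_FATTN_TILE_CASE_STUB(DKQ, DV) \
    template void fattn_tile_case_stub<DKQ, DV>(ctx_stub & ctx, tensor_stub * dst)

// In the header (seen by every translation unit): suppress implicit instantiation.
extern DECL_FATTN_TILE_CASE_STUB(72, 72);

// In the per-head-size .cu file (compiled once): emit the definition.
DECL_FATTN_TILE_CASE_STUB(72, 72);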

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 3 additions & 1 deletion

@@ -3,7 +3,7 @@
 from glob import glob
 import os

-HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]

 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]

@@ -81,6 +81,8 @@ def get_short_name(long_quant_name):
     for head_size_kq in HEAD_SIZES_KQ:
         if head_size_kq == 40:
             continue
+        if head_size_kq == 72:
+            continue
         if head_size_kq != 576 and ncols2 == 16:
             continue
         if head_size_kq == 576 and ncols2 != 16:
