66// nbatch_K == number of K columns to load in parallel for KQ calculation
77
88// TODO optimize kernel parameters for FP16 NVIDIA (P100)
9- // TODO optimize kernel parameters for head sizes 40, 80, 96, 112
9+ // TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
1010
1111// The ROCm compiler cannot handle templating in __launch_bounds__.
1212// As a workaround, define a macro to package the kernel parameters as uint32_t:
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
3232 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 16 , 256 , 2 , 64 , 64 )
3333 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
3434
35+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 64 , 72 )
36+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 64 , 72 )
37+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 64 , 72 )
38+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 64 , 72 )
39+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 64 , 72 )
40+
3541 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 64 , 40 )
3642 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 64 , 40 )
3743 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 64 , 40 )
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
8086 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 16 , 128 , 3 , 64 , 64 )
8187 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
8288
89+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
90+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
91+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
92+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
93+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
94+
8395 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
8496 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
8597 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
130142 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
131143 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 64 , 256 , 2 , 64 , 64 )
132144
145+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
146+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
147+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
148+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
149+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
150+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 64 , 256 , 2 , 32 , 72 )
151+
133152 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
134153 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
135154 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
185204 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 128 , 4 , 64 , 64 )
186205 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 64 , 128 , 5 , 64 , 64 )
187206
207+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
208+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
209+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
210+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
211+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
212+ GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 64 , 256 , 2 , 32 , 72 )
213+
188214 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
189215 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
190216 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
723749
724750 if (
725751#ifdef GGML_USE_WMMA_FATTN
726- (ncols2 != 1 && DV != 40 && DV != 512 ) ||
752+ (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512 ) ||
727753#endif // GGML_USE_WMMA_FATTN
728754 (use_logit_softcap && !(DV == 128 || DV == 256 ))
729755 ) {
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
11981224
11991225extern DECL_FATTN_TILE_CASE ( 40 , 40 );
12001226extern DECL_FATTN_TILE_CASE ( 64 , 64 );
1227+ extern DECL_FATTN_TILE_CASE ( 72 , 72 );
12011228extern DECL_FATTN_TILE_CASE ( 80 , 80 );
12021229extern DECL_FATTN_TILE_CASE ( 96 , 96 );
12031230extern DECL_FATTN_TILE_CASE (112 , 112 );
0 commit comments