66// nbatch_K == number of K columns to load in parallel for KQ calculation
77
88// TODO optimize kernel parameters for FP16 NVIDIA (P100)
9- // TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
9+ // TODO optimize kernel parameters for head sizes 40, 80, 96, 112
1010
1111// The ROCm compiler cannot handle templating in __launch_bounds__.
1212// As a workaround, define a macro to package the kernel parameters as uint32_t:
@@ -32,12 +32,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
3232 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 16 , 256 , 2 , 64 , 64 )
3333 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
3434
35- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 64 , 72 )
36- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 64 , 72 )
37- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 64 , 72 )
38- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 64 , 72 )
39- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 64 , 72 )
40-
4135 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 64 , 40 )
4236 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 64 , 40 )
4337 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 64 , 40 )
@@ -86,12 +80,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
8680 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 16 , 128 , 3 , 64 , 64 )
8781 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
8882
89- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
90- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
91- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
92- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
93- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
94-
9583 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
9684 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
9785 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -142,13 +130,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
142130 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 256 , 2 , 64 , 64 )
143131 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 64 , 256 , 2 , 64 , 64 )
144132
145- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
146- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
147- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
148- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
149- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
150- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 64 , 256 , 2 , 32 , 72 )
151-
152133 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
153134 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
154135 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -204,13 +185,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
204185 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 32 , 128 , 4 , 64 , 64 )
205186 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 64 , 64 , 64 , 128 , 5 , 64 , 64 )
206187
207- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 2 , 64 , 2 , 32 , 72 )
208- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 4 , 128 , 2 , 32 , 72 )
209- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 8 , 256 , 2 , 32 , 72 )
210- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 16 , 256 , 2 , 32 , 72 )
211- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 32 , 256 , 2 , 32 , 72 )
212- GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 72 , 72 , 64 , 256 , 2 , 32 , 72 )
213-
214188 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 2 , 64 , 2 , 32 , 40 )
215189 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 4 , 128 , 2 , 32 , 40 )
216190 GGML_CUDA_FATTN_TILE_CONFIG_CASE ( 80 , 80 , 8 , 256 , 2 , 32 , 40 )
@@ -749,7 +723,7 @@ static __global__ void flash_attn_tile(
749723
750724 if (
751725#ifdef GGML_USE_WMMA_FATTN
752- (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512 ) ||
726+ (ncols2 != 1 && DV != 40 && DV != 512 ) ||
753727#endif // GGML_USE_WMMA_FATTN
754728 (use_logit_softcap && !(DV == 128 || DV == 256 ))
755729 ) {
@@ -1224,7 +1198,6 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
12241198
12251199extern DECL_FATTN_TILE_CASE ( 40 , 40 );
12261200extern DECL_FATTN_TILE_CASE ( 64 , 64 );
1227- extern DECL_FATTN_TILE_CASE ( 72 , 72 );
12281201extern DECL_FATTN_TILE_CASE ( 80 , 80 );
12291202extern DECL_FATTN_TILE_CASE ( 96 , 96 );
12301203extern DECL_FATTN_TILE_CASE (112 , 112 );
0 commit comments