@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(

     const int stride_KV2 = nb11 / sizeof(half2);

-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

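The first hunk swaps the inline ALiBi slope computation for a call to a shared `get_alibi_slope` helper. That helper is defined outside this diff; as a point of reference, here is a minimal sketch of what it presumably looks like, reconstructed from the removed lines above. Its actual name, signature, and location in the shared CUDA headers are assumptions.

```cpp
// Sketch only: reconstructed from the inline ALiBi code removed in the hunk above.
// The real helper is assumed to live in a shared ggml-cuda header.
static __device__ __forceinline__ float get_alibi_slope(
        const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled -> neutral slope
    }
    // Heads below n_head_log2 use base m0, the rest use m1 with odd exponents.
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
```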
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }

 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];

     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case  64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case  64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case  64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
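The second hunk deletes the file-local launcher and routes everything through a generic `launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block)` helper whose definition sits outside this diff. The sketch below approximates what such a helper has to cover, reconstructed purely from the launcher code removed above: grid/block setup, reading `scale` and `max_bias` from the op params, the ALiBi `m0`/`m1` terms, the kernel launch, and the combine pass for `parallel_blocks > 1`. Anything beyond what the removed code shows, including the name used here, is an assumption.

```cpp
// Sketch only: approximation of the shared launch_fattn helper, reconstructed from the
// removed launch_fattn_tile_f16. Assumes the usual ggml-cuda headers (pool allocator,
// fattn_kernel_t typedef, flash_attn_combine_results, CUDA_CHECK, WARP_SIZE) are in scope.
template <int D, int parallel_blocks>
static void launch_fattn_sketch(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
        fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block) {
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];
    ggml_tensor       * KQV  = dst;

    // Temporary buffers for the partial results of the parallel_blocks > 1 split.
    ggml_cuda_pool_alloc<float>  dst_tmp(ctx.pool());
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(ctx.pool());
    if (parallel_blocks > 1) {
        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
    }

    const dim3 block_dim(WARP_SIZE, nwarps, 1);
    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);

    // Softmax scale and ALiBi max_bias are packed into the op params of the dst tensor.
    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

    const uint32_t n_head      = Q->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    fattn_kernel<<<blocks_num, block_dim, 0, ctx.stream()>>>(
            (const char *) Q->data, (const char *) K->data, (const char *) V->data,
            mask ? (const char *) mask->data : nullptr,
            parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
            scale, max_bias, m0, m1, n_head_log2,
            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
            Q->nb[1], Q->nb[2], Q->nb[3],
            K->nb[1], K->nb[2], K->nb[3],
            KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]);
    CUDA_CHECK(cudaGetLastError());

    if (parallel_blocks > 1) {
        // Reduce the per-block partial results into the final output.
        const dim3 block_dim_combine(D, 1, 1);
        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
        flash_attn_combine_results<D, parallel_blocks>
            <<<blocks_num_combine, block_dim_combine, 0, ctx.stream()>>>
            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
        CUDA_CHECK(cudaGetLastError());
    }
}
```

With this logic centralized, the f16 tile path, like the other FlashAttention variants, only has to pick a `fattn_kernel_t` instantiation and its launch parameters, which is exactly what the new `launch_fattn_tile_f16_64_128` does.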