 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template <int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template <int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;
 
-            half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            half sum;
+            if (use_logit_softcap) {
+                const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                sum = logit_softcap * tanhf(tmp.x + tmp.y);
+            } else {
+                sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
             sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
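
For reference, the new `use_logit_softcap` branch applies standard logit soft-capping, s -> c * tanh(s / c) with c = `logit_softcap`; the division by c is not visible in the kernel, presumably because the host launch code folds 1/c into the Q scale before the kernel runs. Below is a minimal standalone C++ sketch of the formula, purely illustrative and not the kernel's code path:

```cpp
#include <cmath>
#include <cstdio>

// Standalone illustration of logit soft-capping (not the kernel's code path):
// scores are squashed smoothly into (-c, +c) before the softmax.
static float softcap(float score, float c) {
    return c * std::tanh(score / c); // ~identity for |score| << c, saturates near +/-c
}

int main() {
    const float c = 50.0f; // example cap, e.g. the attention-logit cap used by Gemma-2-style models
    const float scores[] = {1.0f, 25.0f, 100.0f, 1000.0f};
    for (float s : scores) {
        std::printf("score %8.1f -> capped %8.4f\n", s, softcap(s, c));
    }
    return 0;
}
```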
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
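
The host-side dispatch above turns the runtime `logit_softcap` value read from `op_params` into a compile-time `bool`, so the no-softcap instantiation keeps its original inner loop and the softcap variant is only compiled for the head sizes that need it. A self-contained sketch of that pattern follows; the names are hypothetical (not from the ggml codebase), and it uses `if constexpr` on the host where the CUDA kernel relies on the compiler constant-folding a plain `if` on the template parameter:

```cpp
#include <cstdio>

// Hypothetical sketch of the dispatch pattern: a runtime flag selects one of
// two template instantiations, so the disabled branch is compiled out of the
// hot path instead of being tested per element.
template <bool use_softcap>
static void kernel_stub(float softcap) {
    if constexpr (use_softcap) {     // resolved at compile time
        std::printf("softcap path, c = %.1f\n", softcap);
    } else {
        std::printf("plain path\n");
    }
}

static void dispatch(float softcap) {
    if (softcap == 0.0f) {           // 0.0f means "feature disabled", as in the patch
        kernel_stub<false>(softcap);
    } else {
        kernel_stub<true>(softcap);
    }
}

int main() {
    dispatch(0.0f);   // plain path
    dispatch(50.0f);  // softcap path
    return 0;
}
```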