@@ -82,11 +82,11 @@ static __global__ void flash_attn_ext_f16(
8282 const int sequence = blockIdx .z / ne02;
8383 const int head = blockIdx .z - sequence*ne02;
8484 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
85- const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
86- const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
87- const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
88- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
89- const half2 * mask2 = (const half2 *) maskh;
85+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
86+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
87+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
88+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
89+ const half2 * mask2 = (const half2 *) maskh;
9090 const float * sinksf = (const float *) sinks;
9191
9292 const int stride_Q = nb01 / sizeof (float );
@@ -387,21 +387,20 @@ static __global__ void flash_attn_ext_f16(
387387 const float sinkf = sinksf[head];
388388 const half sinkh = __float2half (sinkf);
389389
390- #pragma unroll
390+ #pragma unroll
391391 for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
392392 const int j = j0 + threadIdx .y ;
393393
394394 if (std::is_same<KQ_acc_t, float >::value) {
395395 float kqmax_new = fmaxf (KQ_max_f[j0/nwarps], sinkf);
396- kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
397396
398397 const float KQ_max_scale = expf (KQ_max_f[j0/nwarps] - kqmax_new);
399398 KQ_max_f[j0/nwarps] = kqmax_new;
400399
401400 KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf (sinkf - KQ_max_f[j0/nwarps]);
402401
403402 const half2 scale_h2 = make_half2 (KQ_max_scale, KQ_max_scale);
404- #pragma unroll
403+ #pragma unroll
405404 for (int i0 = 0 ; i0 < D/2 ; i0 += warp_size) {
406405 const int i = i0 + threadIdx .x ;
407406 if (i0 + warp_size > D/2 && i >= D/2 ) break ;
@@ -410,7 +409,6 @@ static __global__ void flash_attn_ext_f16(
410409 } else {
411410 half kqmax_old = __low2half (KQ_max_h2[j0/nwarps]);
412411 half kqmax_new = fmaxf (kqmax_old, sinkh);
413- kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
414412 KQ_max_h2[j0/nwarps] = __half2half2 (kqmax_new);
415413
416414 const half KQ_max_scale_h = hexp (kqmax_old - kqmax_new);
@@ -420,7 +418,7 @@ static __global__ void flash_attn_ext_f16(
420418 const half val = hexp (sinkh - kqmax_new);
421419 KQ_rowsum_h2[j0/nwarps].x = __hadd (KQ_rowsum_h2[j0/nwarps].x , val);
422420
423- #pragma unroll
421+ #pragma unroll
424422 for (int i0 = 0 ; i0 < D/2 ; i0 += warp_size) {
425423 const int i = i0 + threadIdx .x ;
426424 if (i0 + warp_size > D/2 && i >= D/2 ) break ;
0 commit comments