@@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16(
335335 for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
336336 const int j = j0 + threadIdx .y ;
337337
338+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2 *WARP_SIZE)];
339+ #pragma unroll
340+ for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
341+ const int k = k0 + threadIdx .x ;
342+
343+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2 ) + k];
344+ }
345+
338346 half2 KQ_max_new = KQ_max[j0/nwarps];
339347#pragma unroll
340348 for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
341349 const int k = k0 + threadIdx .x ;
342- half2 val = KQ2[j*(kqs_padded/2 ) + k];
343- val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
344- KQ_max_new = __hmax2 (KQ_max_new, val);
345- KQ2[j*(kqs_padded/2 ) + k] = val;
350+
351+ KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
352+ KQ_max_new = __hmax2 (KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
346353 }
347354 KQ_max_new = __half2half2 (warp_reduce_max (__hmax (__low2half (KQ_max_new), __high2half (KQ_max_new))));
348355 const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
@@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16(
356363 for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
357364 const int k = k0 + threadIdx .x ;
358365
359- half2 val = KQ2[j*(kqs_padded/2 ) + k];
360- const half2 diff = val - KQ_max[j0/nwarps];
361- val = h2exp (diff);
366+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
367+ KQ2_tmp[k0/WARP_SIZE] = h2exp (diff);
362368 const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363- *((uint *) &val ) &= ftz_mask;
364- KQ_rowsum_add += val ;
365- KQ2[j*(kqs_padded/2 ) + k] = val ;
369+ *((uint *) &KQ2_tmp[k0/WARP_SIZE] ) &= ftz_mask;
370+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE] ;
371+ KQ2[j*(kqs_padded/2 ) + k] = KQ2_tmp[k0/WARP_SIZE] ;
366372 }
367373 KQ_rowsum_add = warp_reduce_sum (KQ_rowsum_add);
368374
0 commit comments