File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -404,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
404404 if (ic0 + j_VKQ >= ne01) {
405405 return ;
406406 }
407- const int j_dst = (ic0 + j_VKQ)*gridDim .y + blockIdx .y ;
408407
409408 float KQ_rowsum_j;
410409 if (std::is_same<KQ_acc_t, float >::value) {
@@ -413,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
413412 KQ_rowsum_j = __low2float (KQ_rowsum_h2[j0/nwarps]) + __high2float (KQ_rowsum_h2[j0/nwarps]);
414413 }
415414
415+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim .y + blockIdx .y ;
416+
416417#pragma unroll
417418 for (int i0 = 0 ; i0 < D; i0 += warp_size) {
418419 const int i = i0 + threadIdx .x ;
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
423424 if (gridDim .y == 1 ) {
424425 dst_val /= KQ_rowsum_j;
425426 }
426- dst[((sequence*ne01 + j_dst)*ne02 + head)* D + tid ] = dst_val;
427+ dst[j_dst_unrolled* D + i ] = dst_val;
427428 }
428429
429430 if (gridDim .y == 1 || threadIdx .x != 0 ) {
@@ -437,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
437438 dst_meta_val.x = __low2float (KQ_max_h2[j0/nwarps]);
438439 }
439440 dst_meta_val.y = KQ_rowsum_j;
440- dst_meta[((ic0 + j_VKQ)* gridDim . z + blockIdx . z ) * gridDim . y + blockIdx . y ] = dst_meta_val;
441+ dst_meta[j_dst_unrolled ] = dst_meta_val;
441442 }
442443#else
443444 GGML_UNUSED (Q); GGML_UNUSED (K); GGML_UNUSED (V); GGML_UNUSED (mask);
You can’t perform that action at this time.
0 commit comments