2 files changed: +5 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
81358135 }
81368136
81378137 // V /= S
8138- const float S_inv = 1 .0f /S;
8138+ const float S_inv = S == 0 . 0f ? 0 . 0f : 1 .0f /S;
81398139 ggml_vec_scale_f32 (DV, VKQ32, S_inv);
81408140
81418141 // dst indices
Original file line number | Diff line number | Diff line change
@@ -5201,7 +5201,7 @@ void kernel_flash_attn_ext_impl(
52015201
52025202 device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
52035203
5204- const float scale = 1 .0f /S[jj];
5204+ const float scale = S[jj] == 0.0 ? 0 . 0f : 1 .0f /S[jj];
52055205
52065206 if (DV4 % NW == 0 ) {
52075207 FOR_UNROLL (short ii = 0 ; ii < DV4/NW; ++ii) {
@@ -5821,7 +5821,7 @@ void kernel_flash_attn_ext_vec_impl(
58215821 device float4 * dst4 = (device float4 *) dst;
58225822 device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
58235823
5824- const float S = NWG == 1 ? 1 .0f /ss[0 ] : 1 .0f ;
5824+ const float S = NWG == 1 ? (ss[ 0 ] == 0 . 0f ? 0 . 0f : 1 .0f /ss[0 ]) : 1 .0f ;
58255825
58265826 // interleave the workgroup data
58275827 for (short i = tiisg; i < DV4; i += NW) {
@@ -5999,7 +5999,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
59995999 const float m = simd_max (M);
60006000 const float ms = exp (M - m);
60016001
6002- S = 1 .0f /simd_sum (S*ms);
6002+ S = simd_sum (S*ms);
6003+ S = S == 0 .0f ? 0 .0f : 1 .0f /S;
60036004
60046005 const short DV4 = DV/4 ;
60056006
You can’t perform that action at this time.
0 commit comments