@@ -7046,16 +7046,16 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             V32  [i_gqa] =                  (VKQ32[i_gqa] + 1*DV);
             VKQ16[i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*DV);
             Q_q  [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*DV);
-
+
             if (v->type == GGML_TYPE_F16) {
                 memset(VKQ16[i_gqa], 0, DV*sizeof(ggml_fp16_t));
             } else {
                 memset(VKQ32[i_gqa], 0, DV*sizeof(float));
             }
-
+
             const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
             q_to_vec_dot(pq, Q_q[i_gqa], DK);
-
+
             const uint32_t h = iq2 + i_gqa;
             slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
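The slope[i_gqa] expression at the end of this hunk is the usual ALiBi per-head slope selection: heads below n_head_log2 take powers of m0, the remaining heads take odd powers of m1, and a non-positive max_bias disables the bias entirely. A minimal standalone sketch of that expression follows; the name alibi_slope is hypothetical, and m0, m1 and n_head_log2 are assumed to be the precomputed ALiBi bases this function already uses.

#include <math.h>
#include <stdint.h>

// Hypothetical helper mirroring the slope[i_gqa] expression above.
// h is the absolute head index, i.e. iq2 + i_gqa in the kernel.
static float alibi_slope(float max_bias, float m0, float m1,
                         uint32_t n_head_log2, uint32_t h) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled -> neutral slope
    }
    return h < n_head_log2 ? powf(m0, h + 1)
                           : powf(m1, 2*(h - n_head_log2) + 1);
}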
@@ -7083,7 +7083,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
                 const float mv = mp_value_base * slope[i_gqa];
                 ggml_compute_forward_flash_attn_ext_f16_one_QKV(
-                    Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
+                    Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
                     kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
             }
         }
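The per-head helper ggml_compute_forward_flash_attn_ext_f16_one_QKV is not part of this hunk; it receives the running softmax sum S and running maximum M for each head, which is the standard streaming-softmax (flash-attention) recurrence. Below is a purely illustrative sketch of one key/value update for plain float vectors; the name one_kv_update is hypothetical, and the real helper additionally handles quantized/FP16 data through kq_vec_dot and v_to_float and applies logit_softcap, all omitted here.

#include <math.h>

// Hypothetical single streaming-softmax step for one query against one K/V pair.
// Assumes the caller initialized *S = 0.0f, *M = -INFINITY and VKQ[] = 0.
static void one_kv_update(const float * q, const float * k, const float * v,
                          int DK, int DV, float mask_bias, float scale,
                          float * VKQ, float * S, float * M) {
    float s = 0.0f;
    for (int d = 0; d < DK; ++d) {
        s += q[d]*k[d];          // raw attention logit
    }
    s = s*scale + mask_bias;     // scaled logit plus slope-weighted mask value

    const float M_old = *M;
    const float M_new = s > M_old ? s : M_old;

    const float ms = expf(M_old - M_new); // rescale factor for the old accumulator
    const float vs = expf(s - M_new);     // weight of the new value row

    for (int d = 0; d < DV; ++d) {
        VKQ[d] = VKQ[d]*ms + v[d]*vs;     // running weighted sum of V
    }

    *S = *S*ms + vs;                      // running softmax denominator
    *M = M_new;                           // running maximum logit
}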
@@ -7094,19 +7094,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                 }
             }
-
+
             // V /= S
             const float S_inv = 1.0f/S[i];
             ggml_vec_scale_f32(DV, VKQ32[i], S_inv);
-
+
             // dst indices
             const int i1 = iq1;
             const int i2 = iq2 + i;
             const int i3 = iq3;
-
+
             // original
             // memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
+
             // permute(0, 2, 1, 3)
             memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
         }
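The subtle part of this last hunk is the destination indexing: instead of the commented-out "original" layout (i1*nb1 + i2*nb2 + i3*nb3), the result row is written at (i3*ne2*ne1 + i2 + i1*ne1)*nb1, the permute(0, 2, 1, 3) layout noted in the comment. A small sketch restating the same arithmetic; dst_row_offset is a hypothetical name, not part of the kernel.

#include <stddef.h>
#include <stdint.h>

// Hypothetical restatement of the index arithmetic used by the final memcpy:
// rows of stride nb1 are laid out with i2 varying fastest within blocks of
// ne1 rows, which yields the permute(0, 2, 1, 3) ordering of the output.
static inline size_t dst_row_offset(size_t nb1, int64_t ne1, int64_t ne2,
                                    int64_t i1, int64_t i2, int64_t i3) {
    return (size_t) (i3*ne2*ne1 + i1*ne1 + i2) * nb1;
}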