@@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf = false,
                                                             bool skip_reshape = false,
-                                                            bool flash_attn = false) {
+                                                            bool flash_attn = false,
+                                                            float kv_scale = 1.0f) { // avoid overflow in the F16 K/V cast
     int64_t L_q;
     int64_t L_k;
     int64_t C;
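The new kv_scale parameter defaults to 1.0f, so existing call sites compile unchanged and skip the new ggml_scale nodes entirely. As a minimal sketch of a caller that opts in, the leading arguments (ctx, q, k, v, n_head) are assumed from the part of the signature outside this hunk, and 0.125f is an arbitrary example factor:

    // Hypothetical call site (leading arguments assumed, not shown in this diff):
    // pre-shrink K/V by 1/8 so their F16 copies stay inside the representable range.
    struct ggml_tensor* x = ggml_nn_attention_ext(ctx, q, k, v, n_head,
                                                  NULL,    // mask
                                                  false,   // diag_mask_inf
                                                  false,   // skip_reshape
                                                  true,    // flash_attn
                                                  0.125f); // kv_scale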
@@ -1175,13 +1176,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            k_in = ggml_scale(ctx, k_in, kv_scale);
+        }
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);

         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
         v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            v_in = ggml_scale(ctx, v_in, kv_scale);
+        }
         v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);

         if (mask_in != nullptr) {
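Both scales are applied before the ggml_cast to GGML_TYPE_F16, which is the point of the change: F16 cannot represent magnitudes above 65504, so oversized K/V values would turn into inf during the cast. A standalone sketch of the failure mode using ggml's public fp16 helpers (the 3.0e5f input and the 0.125f factor are illustrative values, not taken from the diff):

    #include <cstdio>
    #include "ggml.h"

    int main() {
        const float big      = 3.0e5f;  // example K/V magnitude above the F16 max of 65504
        const float kv_scale = 0.125f;  // example down-scale factor

        // Round-trip through F16, as the graph's ggml_cast would do.
        float direct = ggml_fp16_to_fp32(ggml_fp32_to_fp16(big));             // inf
        float scaled = ggml_fp16_to_fp32(ggml_fp32_to_fp16(big * kv_scale));  // ~37504, finite
        printf("direct: %f, scaled: %f\n", direct, scaled);
        return 0;
    }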
@@ -1205,8 +1212,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
         }

-        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
         ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        if (kv_scale != 1.0f) {
+            out = ggml_scale(ctx, out, 1.0f / kv_scale);
+        }
         return out;
     };

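Why this compensation is exact: writing s for kv_scale, the graph now evaluates attention on (Q, sK, sV) with the softmax scale divided by s, and rescales at the end:

    softmax((scale / s) * Q * (s*K)^T + mask) * (s*V)
      = softmax(scale * Q * K^T + mask) * (s*V)
      = s * [ softmax(scale * Q * K^T + mask) * V ]

so the final ggml_scale(ctx, out, 1.0f / s) restores the original attention output. The scaling changes only the intermediate values that pass through the F16 casts, keeping them away from the F16 overflow threshold without altering the mathematical result.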