
Commit 477911f

fix qwen image flash attn
1 parent cf19c6e commit 477911f


3 files changed: +17 -6 lines changed

flux.hpp

Lines changed: 3 additions & 2 deletions
@@ -120,14 +120,15 @@ namespace Flux {
                                      struct ggml_tensor* v,
                                      struct ggml_tensor* pe,
                                      struct ggml_tensor* mask,
-                                     bool flash_attn) {
+                                     bool flash_attn,
+                                     float kv_scale = 1.0f) {
         // q,k,v: [N, L, n_head, d_head]
         // pe: [L, d_head/2, 2, 2]
         // return: [N, L, n_head*d_head]
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
 
-        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
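
Note: the new kv_scale parameter defaults to 1.0f, so Flux callers that do not pass it keep the previous behavior; the wrapper only forwards the value to ggml_nn_attention_ext, which does the actual scaling (see below).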

ggml_extend.hpp

Lines changed: 12 additions & 2 deletions
@@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf   = false,
                                                             bool skip_reshape    = false,
-                                                            bool flash_attn      = false) {
+                                                            bool flash_attn      = false,  // avoid overflow
+                                                            float kv_scale       = 1.0f) {
     int64_t L_q;
     int64_t L_k;
     int64_t C;
@@ -1175,13 +1176,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            k_in = ggml_scale(ctx, k_in, kv_scale);
+        }
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
 
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
         v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            v_in = ggml_scale(ctx, v_in, kv_scale);
+        }
         v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
 
         if (mask_in != nullptr) {
@@ -1205,8 +1212,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
         }
 
-        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
         ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        if (kv_scale != 1.0f) {
+            out = ggml_scale(ctx, out, 1.0f / kv_scale);
+        }
         return out;
     };
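
The kv_scale handling only touches the flash-attention path: K and V are shrunk by kv_scale before the F16 cast (so large Qwen-Image activations do not overflow half precision), the softmax scale is divided by kv_scale to cancel the shrink in the Q·K logits, and the output is multiplied by 1.0f / kv_scale to undo the shrink on V. In exact arithmetic the result is unchanged. Below is a minimal standalone sketch (not part of this commit; the attn helper and the test values are invented for illustration) that checks the identity with plain C++:

// Standalone check of the kv_scale trick: scaling K and V by s, dividing the
// softmax scale by s, and multiplying the output by 1/s reproduces plain
// attention, while keeping the K/V values that get cast to F16 smaller.
#include <cmath>
#include <cstdio>
#include <vector>

// single-query, single-head attention: out = softmax(scale * q.K^T) . V
static std::vector<float> attn(const std::vector<float>& q,
                               const std::vector<std::vector<float>>& K,
                               const std::vector<std::vector<float>>& V,
                               float scale) {
    size_t L = K.size(), d = q.size();
    std::vector<float> logits(L), out(V[0].size(), 0.0f);
    float maxv = -1e30f, denom = 0.0f;
    for (size_t i = 0; i < L; i++) {
        float dot = 0.0f;
        for (size_t j = 0; j < d; j++) dot += q[j] * K[i][j];
        logits[i] = scale * dot;
        if (logits[i] > maxv) maxv = logits[i];
    }
    for (size_t i = 0; i < L; i++) {
        logits[i] = std::exp(logits[i] - maxv);  // numerically stable softmax
        denom += logits[i];
    }
    for (size_t i = 0; i < L; i++)
        for (size_t j = 0; j < out.size(); j++)
            out[j] += (logits[i] / denom) * V[i][j];
    return out;
}

int main() {
    const float kv_scale = 1.0f / 128.0f;  // same constant as qwen_image.hpp
    std::vector<float> q = {300.0f, -250.0f};
    std::vector<std::vector<float>> K = {{400.0f, 100.0f}, {-200.0f, 350.0f}};
    std::vector<std::vector<float>> V = {{500.0f, -300.0f}, {250.0f, 450.0f}};
    float scale = 1.0f / std::sqrt(2.0f);

    // reference: plain attention
    auto ref = attn(q, K, V, scale);

    // kv_scale path: shrink K and V, divide the softmax scale, undo on output
    auto Ks = K, Vs = V;
    for (auto& row : Ks) for (auto& x : row) x *= kv_scale;
    for (auto& row : Vs) for (auto& x : row) x *= kv_scale;
    auto out = attn(q, Ks, Vs, scale / kv_scale);
    for (auto& x : out) x *= 1.0f / kv_scale;

    printf("ref: %f %f\nout: %f %f\n", ref[0], ref[1], out[0], out[1]);
    return 0;
}

Both printed rows should match; only the intermediate K/V magnitudes differ by the factor kv_scale.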

qwen_image.hpp

Lines changed: 2 additions & 2 deletions
@@ -156,8 +156,8 @@ namespace Qwen {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
-            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));           // [n_txt_token + n_img_token, N, hidden_size]
+            auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                           // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx,
                                              attn,
                                              attn->ne[0],
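
For Qwen-Image's joint text+image attention, kv_scale is hard-coded to 1.0f / 128.f: K and V are divided by 128 before the F16 cast inside ggml_nn_attention_ext, which keeps large activations below the F16 maximum of 65504 (for example, a value of 1.0e5 overflows to infinity in F16, while 1.0e5 / 128 ≈ 781.25 is comfortably in range), and the scale/output compensation above restores the original magnitudes up to rounding.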
