@@ -1119,9 +1119,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
11191119 return kqv;
11201120}
11211121
1122- // q: [N, L_q, C] or [N*n_head, L_q, d_head]
1123- // k: [N, L_k, C ] or [N*n_head , L_k, d_head]
1124- // v: [N, L_k, C ] or [N, L_k, n_head , d_head]
1122+ // q: [N, L_q, C(n_head*d_head) ] or [N*n_head, L_q, d_head]
1123+ // k: [N, L_k, n_kv_head*d_head ] or [N*n_kv_head , L_k, d_head]
1124+ // v: [N, L_k, n_kv_head*d_head ] or [N, L_k, n_kv_head , d_head]
11251125// mask: [N, L_q, L_k]
11261126// return: [N, L_q, C]
11271127__STATIC_INLINE__ struct ggml_tensor * ggml_nn_attention_ext (struct ggml_context * ctx,
@@ -1139,27 +1139,31 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
11391139 int64_t C;
11401140 int64_t N;
11411141 int64_t d_head;
1142+ int64_t n_kv_head;
11421143 if (!skip_reshape) {
1143- L_q = q->ne [1 ];
1144- L_k = k->ne [1 ];
1145- C = q->ne [0 ];
1146- N = q->ne [2 ];
1147- d_head = C / n_head;
1148- q = ggml_reshape_4d (ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
1149- q = ggml_nn_cont (ctx, ggml_permute (ctx, q, 0 , 2 , 1 , 3 )); // [N, n_head, L_q, d_head]
1150- q = ggml_reshape_3d (ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
1151-
1152- k = ggml_reshape_4d (ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
1153- k = ggml_nn_cont (ctx, ggml_permute (ctx, k, 0 , 2 , 1 , 3 )); // [N, n_head, L_k, d_head]
1154- k = ggml_reshape_3d (ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
1155-
1156- v = ggml_reshape_4d (ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
1144+ L_q = q->ne [1 ];
1145+ L_k = k->ne [1 ];
1146+ C = q->ne [0 ];
1147+ N = q->ne [2 ];
1148+ d_head = C / n_head;
1149+ n_kv_head = k->ne [0 ] / d_head;
1150+
1151+ q = ggml_reshape_4d (ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
1152+ q = ggml_nn_cont (ctx, ggml_permute (ctx, q, 0 , 2 , 1 , 3 )); // [N, n_head, L_q, d_head]
1153+ q = ggml_reshape_3d (ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
1154+
1155+ k = ggml_reshape_4d (ctx, k, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
1156+ k = ggml_nn_cont (ctx, ggml_permute (ctx, k, 0 , 2 , 1 , 3 )); // [N, n_kv_head, L_k, d_head]
1157+ k = ggml_reshape_3d (ctx, k, d_head, L_k, n_kv_head * N); // [N * n_kv_head, L_k, d_head]
1158+
1159+ v = ggml_reshape_4d (ctx, v, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
11571160 } else {
1158- L_q = q->ne [1 ];
1159- L_k = k->ne [1 ];
1160- d_head = v->ne [0 ];
1161- N = v->ne [3 ];
1162- C = d_head * n_head;
1161+ L_q = q->ne [1 ];
1162+ L_k = k->ne [1 ];
1163+ d_head = v->ne [0 ];
1164+ N = v->ne [3 ];
1165+ n_kv_head = k->ne [2 ] / N;
1166+ C = d_head * n_head;
11631167 }
11641168
11651169 float scale = (1 .0f / sqrt ((float )d_head));
@@ -1174,7 +1178,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
11741178 k_in = ggml_cast (ctx, k_in, GGML_TYPE_F16);
11751179
11761180 v_in = ggml_nn_cont (ctx, ggml_permute (ctx, v_in, 0 , 2 , 1 , 3 ));
1177- v_in = ggml_reshape_3d (ctx, v_in, d_head, L_k, n_head * N);
1181+ v_in = ggml_reshape_3d (ctx, v_in, d_head, L_k, n_kv_head * N);
11781182 if (kv_pad != 0 ) {
11791183 v_in = ggml_pad (ctx, v_in, 0 , kv_pad, 0 , 0 );
11801184 }
@@ -1232,8 +1236,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
12321236 // if (flash_attn) {
12331237 // LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
12341238 // }
1235- v = ggml_nn_cont (ctx, ggml_permute (ctx, v, 1 , 2 , 0 , 3 )); // [N, n_head , d_head, L_k]
1236- v = ggml_reshape_3d (ctx, v, L_k, d_head, n_head * N); // [N * n_head , d_head, L_k]
1239+ v = ggml_nn_cont (ctx, ggml_permute (ctx, v, 1 , 2 , 0 , 3 )); // [N, n_kv_head , d_head, L_k]
1240+ v = ggml_reshape_3d (ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head , d_head, L_k]
12371241
12381242 auto kq = ggml_mul_mat (ctx, k, q); // [N * n_head, L_q, L_k]
12391243 kq = ggml_scale_inplace (ctx, kq, scale);
0 commit comments