@@ -1034,17 +1034,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
10341034 const bool v_trans = v->nb [1 ] > v->nb [2 ];
10351035
10361036 // split the batch into streams if needed
1037- const auto n_stream = cparams. kv_unified ? 1 : ubatch. n_seqs_unq ;
1037+ const auto n_stream = k-> ne [ 3 ] ;
10381038
10391039 q = ggml_reshape_4d (ctx0, q, q->ne [0 ], q->ne [1 ], q->ne [2 ]/n_stream, n_stream);
10401040
10411041 q = ggml_permute (ctx0, q, 0 , 2 , 1 , 3 );
10421042 k = ggml_permute (ctx0, k, 0 , 2 , 1 , 3 );
10431043 v = ggml_permute (ctx0, v, 0 , 2 , 1 , 3 );
10441044
1045- const auto n_tokens = q->ne [1 ];
1046- const auto n_head = q->ne [2 ];
1047- const auto n_kv = k->ne [1 ];
1045+ const auto n_kv = k->ne [1 ];
10481046
10491047 ggml_tensor * cur;
10501048
@@ -1086,8 +1084,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
10861084#endif
10871085 }
10881086
1089- // recombine streams
1090- cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream);
1087+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*cur->ne [1 ], cur->ne [2 ]*cur->ne [3 ]);
10911088 } else {
10921089 ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
10931090
@@ -1133,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
11331130 cur = ggml_permute (ctx0, kqv, 0 , 2 , 1 , 3 );
11341131
11351132 // recombine streams
1136- cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream );
1133+ cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*cur-> ne [ 1 ], cur-> ne [ 2 ]*cur-> ne [ 3 ] );
11371134
11381135 if (!cparams.offload_kqv ) {
11391136 // all nodes between the KV store and the attention output are run on the CPU
@@ -1180,6 +1177,10 @@ ggml_tensor * llm_graph_context::build_attn(
11801177
11811178 const auto & kq_mask = inp->get_kq_mask ();
11821179
1180+ // [TAG_NO_CACHE_PAD]
1181+ // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1182+ assert (ubatch.equal_seqs == false );
1183+
11831184 ggml_tensor * q = q_cur;
11841185 ggml_tensor * k = k_cur;
11851186 ggml_tensor * v = v_cur;
0 commit comments