@@ -983,17 +983,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     const bool v_trans = v->nb[1] > v->nb[2];

     // split the batch into streams if needed
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = k->ne[3];

     q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);

-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_kv = k->ne[1];

     ggml_tensor * cur;

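To see what this first hunk buys, here is a small standalone sketch, not taken from llama.cpp: the identifiers (n_embd_head, n_head_kv, n_kv, n_stream) and the assumed K-view layout [n_embd_head, n_head_kv, n_kv, n_stream] are illustrative choices. It builds metadata-only tensors with no_alloc and shows the stream count being recovered from k->ne[3], with Q split the same way before the (0, 2, 1, 3) permutes.

#include "ggml.h"
#include <cstdio>

int main() {
    // metadata-only context: enough memory for tensor headers, no data buffers
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8;
    const int64_t n_kv = 256, n_stream = 4, n_tokens = 16; // n_tokens counts all streams together

    // assumed K view layout: one slice per sequence stream in dim 3
    ggml_tensor * k = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, n_embd_head, n_head_kv, n_kv, n_stream);

    // what the patched build_attn_mha now does: derive the stream count from K itself
    const int64_t n_stream_from_k = k->ne[3];

    // Q arrives as [n_embd_head, n_head, n_tokens]; split the token dim across streams
    ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream_from_k, n_stream_from_k);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
    k = ggml_permute(ctx0, k, 0, 2, 1, 3);

    // after the permutes: q = [n_embd_head, tokens_per_stream, n_head, n_stream],
    //                     k = [n_embd_head, n_kv, n_head_kv, n_stream], so n_kv == k->ne[1]
    printf("n_stream = %lld, n_kv = %lld, tokens per stream = %lld\n",
           (long long) n_stream_from_k, (long long) k->ne[1], (long long) q->ne[1]);

    ggml_free(ctx0);
    return 0;
}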
@@ -1035,8 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        // recombine streams
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1082,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

         // recombine streams
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
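The two recombine changes are pure bookkeeping: instead of caching n_head, n_tokens and n_stream up front, the 2D target shape is read from cur's own dims. Below is a standalone check in plain C++ that the two formulations agree; the sizes are arbitrary and the assumed layout of cur at this point ([n_embd_head_v, n_tokens, n_head, n_stream] before the final permute) is my reading of the code, not stated in the diff.

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

using shape4 = std::array<int64_t, 4>;

// models ggml_permute(ctx0, t, 0, 2, 1, 3): dims 1 and 2 swap places
static shape4 permute_0213(const shape4 & ne) {
    return { ne[0], ne[2], ne[1], ne[3] };
}

int main() {
    const int64_t n_embd_head_v = 128, n_head = 32, n_stream = 4;
    const int64_t n_tokens = 4; // tokens per stream, i.e. q->ne[1] after the permute

    // kqv right before the recombine: [n_embd_head_v, n_tokens, n_head, n_stream]
    const shape4 kqv = { n_embd_head_v, n_tokens, n_head, n_stream };

    // cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3)
    const shape4 cur = permute_0213(kqv);

    // new form, derived from cur itself ...
    const int64_t ne0 = cur[0]*cur[1];
    const int64_t ne1 = cur[2]*cur[3];

    // ... matches the old explicit bookkeeping: cur->ne[0]*n_head x n_tokens*n_stream
    assert(ne0 == cur[0]*n_head);
    assert(ne1 == n_tokens*n_stream);

    printf("recombined to [%lld, %lld]\n", (long long) ne0, (long long) ne1);
    return 0;
}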
@@ -1129,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(

     const auto & kq_mask = inp->get_kq_mask();

+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(ubatch.equal_seqs == false);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
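The last hunk only adds a guard in the no-KV-cache path of build_attn: splitting q/k/v into ubatch.n_seqs_unq streams is left as a TODO, and the assert makes the current limitation explicit. A reduced sketch of that guard follows; the ubatch_stub type is a hypothetical stand-in for llama_ubatch, trimmed to the one field the hunk touches.

#include <cassert>

// hypothetical stand-in for llama_ubatch, reduced to the field used by the new assert
struct ubatch_stub {
    bool equal_seqs = false;
};

// mirrors the guard added in build_attn: the per-sequence stream split is still a TODO,
// so this path currently only accepts ubatches that are not flagged equal_seqs
static void check_no_cache_attn_input(const ubatch_stub & ubatch) {
    // [TAG_NO_CACHE_PAD]
    assert(ubatch.equal_seqs == false);
}

int main() {
    ubatch_stub ub;
    check_no_cache_attn_input(ub); // ok; ub.equal_seqs = true would trip the assert
    return 0;
}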