@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
771771 GGML_ASSERT (ubatch.seq_id [s*n_tokens][0 ] == seq_id);
772772 }
773773
774- res.s0 = std::min<llama_seq_id >(res.s0 , seq_to_stream[seq_id]);
775- res.s1 = std::max<llama_seq_id >(res.s1 , seq_to_stream[seq_id]);
774+ res.s0 = std::min<uint32_t >(res.s0 , seq_to_stream[seq_id]);
775+ res.s1 = std::max<uint32_t >(res.s1 , seq_to_stream[seq_id]);
776776
777777 res.strm [s] = seq_to_stream[seq_id];
778778 res.idxs [s].reserve (n_tokens);
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
964964 return result;
965965}
966966
967- uint32_t llama_kv_cache::get_n_kv () const {
967+ uint32_t llama_kv_cache::get_n_kv (const slot_info & sinfo ) const {
968968 uint32_t result = 0 ;
969969
970- for (uint32_t s = 0 ; s < n_stream; ++s) {
971- const auto & cells = v_cells[s ];
970+ for (uint32_t s = 0 ; s < sinfo. n_stream () ; ++s) {
971+ const auto & cells = v_cells[sinfo. strm [s] ];
972972
973973 result = std::max (std::min (cells.size (), std::max (n_pad, GGML_PAD (cells.used_max_p1 (), n_pad))), result);
974974 }
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
10171017 // note: v->nb[1] <= v->nb[2]
10181018 return ggml_view_4d (ctx, v,
10191019 hparams.n_embd_head_v , hparams.n_head_kv (il), n_kv, ns,
1020- ggml_row_size (v->type , hparams.n_embd_head_v ), // v->nb[1]
1021- ggml_row_size (v->type , n_embd_v_gqa), // v->nb[2]
1022- ggml_row_size (v->type , n_embd_v_gqa*kv_size), // v->nb[3]
1020+ ggml_row_size (v->type , hparams.n_embd_head_v ), // v->nb[1]
1021+ ggml_row_size (v->type , n_embd_v_gqa), // v->nb[2]
1022+ ggml_row_size (v->type , n_embd_v_gqa*kv_size), // v->nb[3]
10231023 ggml_row_size (v->type , n_embd_v_gqa*kv_size)*sinfo.s0 );
10241024 }
10251025
10261026 // note: v->nb[1] > v->nb[2]
10271027 return ggml_view_4d (ctx, v,
10281028 n_kv, hparams.n_head_kv (il), hparams.n_embd_head_v , ns,
1029- ggml_row_size (v->type , kv_size*hparams.n_embd_head_v ), // v->nb[1]
1030- ggml_row_size (v->type , kv_size), // v->nb[2]
1031- ggml_row_size (v->type , kv_size*n_embd_v_gqa), // v->nb[3]
1029+ ggml_row_size (v->type , kv_size*hparams.n_embd_head_v ), // v->nb[1]
1030+ ggml_row_size (v->type , kv_size), // v->nb[2]
1031+ ggml_row_size (v->type , kv_size*n_embd_v_gqa), // v->nb[3]
10321032 ggml_row_size (v->type , kv_size*n_embd_v_gqa)*sinfo.s0 );
10331033}
10341034
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
19851985 }
19861986
19871987 kv->apply_ubatch (sinfos[i_cur], ubatches[i_cur]);
1988-
1989- n_kv = kv->get_n_kv ();
1988+ n_kv = kv->get_n_kv (sinfos[i_cur]);
19901989
19911990 return true ;
19921991}
0 commit comments