@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
771
771
GGML_ASSERT (ubatch.seq_id [s*n_tokens][0 ] == seq_id);
772
772
}
773
773
774
- res.s0 = std::min<llama_seq_id >(res.s0 , seq_to_stream[seq_id]);
775
- res.s1 = std::max<llama_seq_id >(res.s1 , seq_to_stream[seq_id]);
774
+ res.s0 = std::min<uint32_t >(res.s0 , seq_to_stream[seq_id]);
775
+ res.s1 = std::max<uint32_t >(res.s1 , seq_to_stream[seq_id]);
776
776
777
777
res.strm [s] = seq_to_stream[seq_id];
778
778
res.idxs [s].reserve (n_tokens);
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
964
964
return result;
965
965
}
966
966
967
- uint32_t llama_kv_cache::get_n_kv () const {
967
+ uint32_t llama_kv_cache::get_n_kv (const slot_info & sinfo ) const {
968
968
uint32_t result = 0 ;
969
969
970
- for (uint32_t s = 0 ; s < n_stream; ++s) {
971
- const auto & cells = v_cells[s ];
970
+ for (uint32_t s = 0 ; s < sinfo. n_stream () ; ++s) {
971
+ const auto & cells = v_cells[sinfo. strm [s] ];
972
972
973
973
result = std::max (std::min (cells.size (), std::max (n_pad, GGML_PAD (cells.used_max_p1 (), n_pad))), result);
974
974
}
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
1017
1017
// note: v->nb[1] <= v->nb[2]
1018
1018
return ggml_view_4d (ctx, v,
1019
1019
hparams.n_embd_head_v , hparams.n_head_kv (il), n_kv, ns,
1020
- ggml_row_size (v->type , hparams.n_embd_head_v ), // v->nb[1]
1021
- ggml_row_size (v->type , n_embd_v_gqa), // v->nb[2]
1022
- ggml_row_size (v->type , n_embd_v_gqa*kv_size), // v->nb[3]
1020
+ ggml_row_size (v->type , hparams.n_embd_head_v ), // v->nb[1]
1021
+ ggml_row_size (v->type , n_embd_v_gqa), // v->nb[2]
1022
+ ggml_row_size (v->type , n_embd_v_gqa*kv_size), // v->nb[3]
1023
1023
ggml_row_size (v->type , n_embd_v_gqa*kv_size)*sinfo.s0 );
1024
1024
}
1025
1025
1026
1026
// note: v->nb[1] > v->nb[2]
1027
1027
return ggml_view_4d (ctx, v,
1028
1028
n_kv, hparams.n_head_kv (il), hparams.n_embd_head_v , ns,
1029
- ggml_row_size (v->type , kv_size*hparams.n_embd_head_v ), // v->nb[1]
1030
- ggml_row_size (v->type , kv_size), // v->nb[2]
1031
- ggml_row_size (v->type , kv_size*n_embd_v_gqa), // v->nb[3]
1029
+ ggml_row_size (v->type , kv_size*hparams.n_embd_head_v ), // v->nb[1]
1030
+ ggml_row_size (v->type , kv_size), // v->nb[2]
1031
+ ggml_row_size (v->type , kv_size*n_embd_v_gqa), // v->nb[3]
1032
1032
ggml_row_size (v->type , kv_size*n_embd_v_gqa)*sinfo.s0 );
1033
1033
}
1034
1034
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
1985
1985
}
1986
1986
1987
1987
kv->apply_ubatch (sinfos[i_cur], ubatches[i_cur]);
1988
-
1989
- n_kv = kv->get_n_kv ();
1988
+ n_kv = kv->get_n_kv (sinfos[i_cur]);
1990
1989
1991
1990
return true ;
1992
1991
}
0 commit comments