Skip to content

Commit 1bded5a

Browse files
authored
kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
ggml-ci
1 parent 1e74897 commit 1bded5a

File tree

2 files changed: +15 additions, -16 deletions

src/llama-kv-cache.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
771771
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
772772
}
773773

774-
res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
775-
res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
774+
res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
775+
res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
776776

777777
res.strm[s] = seq_to_stream[seq_id];
778778
res.idxs[s].reserve(n_tokens);
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
964964
return result;
965965
}
966966

967-
uint32_t llama_kv_cache::get_n_kv() const {
967+
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
968968
uint32_t result = 0;
969969

970-
for (uint32_t s = 0; s < n_stream; ++s) {
971-
const auto & cells = v_cells[s];
970+
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
971+
const auto & cells = v_cells[sinfo.strm[s]];
972972

973973
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
974974
}
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
10171017
// note: v->nb[1] <= v->nb[2]
10181018
return ggml_view_4d(ctx, v,
10191019
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1020-
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1021-
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1022-
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1020+
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1021+
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1022+
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
10231023
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
10241024
}
10251025

10261026
// note: v->nb[1] > v->nb[2]
10271027
return ggml_view_4d(ctx, v,
10281028
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1029-
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1030-
ggml_row_size(v->type, kv_size), // v->nb[2]
1031-
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1029+
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1030+
ggml_row_size(v->type, kv_size), // v->nb[2]
1031+
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
10321032
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
10331033
}
10341034

@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
19851985
}
19861986

19871987
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
1988-
1989-
n_kv = kv->get_n_kv();
1988+
n_kv = kv->get_n_kv(sinfos[i_cur]);
19901989

19911990
return true;
19921991
}

src/llama-kv-cache.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class llama_kv_cache : public llama_memory_i {
3838
using idx_vec_t = std::vector<uint32_t>;
3939

4040
// number of streams: ns = s1 - s0 + 1
41-
llama_seq_id s0;
42-
llama_seq_id s1;
41+
uint32_t s0;
42+
uint32_t s1;
4343

4444
std::vector<llama_seq_id> strm; // [ns]
4545
std::vector<idx_vec_t> idxs; // [ns]
@@ -139,7 +139,7 @@ class llama_kv_cache : public llama_memory_i {
139139
// graph_build API
140140
//
141141

142-
uint32_t get_n_kv() const;
142+
uint32_t get_n_kv(const slot_info & sinfo) const;
143143

144144
// TODO: temporary
145145
bool get_supports_set_rows() const;

0 commit comments

Comments (0)