src/llama-kv-cache.cpp — 25 changes: 12 additions & 13 deletions

@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
             GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
         }

-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);

        res.strm[s] = seq_to_stream[seq_id];
        res.idxs[s].reserve(n_tokens);
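Note on this hunk: llama_seq_id is a signed int32_t (per llama.h), while the stream indices from seq_to_stream are unsigned, so instantiating std::min/std::max on uint32_t keeps the comparison in a single type now that s0/s1 are themselves uint32_t (see the header change below). A minimal standalone sketch of the accumulation, with made-up stream indices and made-up starting bounds (the real initial values live in find_slot and are not shown in this diff):

#include <algorithm>
#include <cstdint>

int main() {
    // bounds start "inverted" so the first observed stream sets both
    uint32_t s0 = UINT32_MAX;
    uint32_t s1 = 0;

    const uint32_t streams[] = {3, 1, 2}; // stand-ins for seq_to_stream[seq_id]
    for (uint32_t strm : streams) {
        s0 = std::min<uint32_t>(s0, strm);
        s1 = std::max<uint32_t>(s1, strm);
    }

    return (s0 == 1 && s1 == 3) ? 0 : 1; // ns = s1 - s0 + 1 == 3
}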
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
    return result;
}

-uint32_t llama_kv_cache::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;

-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const auto & cells = v_cells[sinfo.strm[s]];

        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
    }
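get_n_kv now bounds the graph's KV length over just the streams the slot touches (sinfo.strm) rather than all n_stream streams. A standalone sketch of the bound, assuming GGML_PAD(x, n) rounds x up to the next multiple of n as in ggml.h, with plain arrays standing in for cells.used_max_p1() and cells.size():

#include <algorithm>
#include <cstdint>
#include <cstdio>

// same shape as ggml.h's GGML_PAD: round x up to a multiple of n
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

// per stream: pad the used high-water mark, clamp to the cell count,
// then keep the maximum across all streams in the slot
static uint32_t n_kv_bound(const uint32_t * used_max_p1, const uint32_t * size,
                           uint32_t n_stream, uint32_t n_pad) {
    uint32_t result = 0;
    for (uint32_t s = 0; s < n_stream; ++s) {
        result = std::max(std::min(size[s], std::max(n_pad, pad_up(used_max_p1[s], n_pad))), result);
    }
    return result;
}

int main() {
    const uint32_t used[] = {100, 250};
    const uint32_t size[] = {4096, 4096};
    printf("%u\n", n_kv_bound(used, size, 2, 32)); // 256 (250 padded up to 32)
    return 0;
}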
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
        // note: v->nb[1] <= v->nb[2]
        return ggml_view_4d(ctx, v,
                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_head_v),        // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),                 // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),         // v->nb[3]
                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
    }

    // note: v->nb[1] > v->nb[2]
    return ggml_view_4d(ctx, v,
            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, kv_size),                       // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
+            ggml_row_size(v->type, kv_size),                        // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3]
            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
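This hunk only realigns the stride comments; the final ggml_row_size(...)*sinfo.s0 argument (unchanged) is the view's byte offset, and since v->nb[3] is the per-stream stride, it starts the ns-stream view at the slot's first stream s0. A made-up-sizes sketch of why the two branches order the strides differently (for non-quantized types, ggml_row_size reduces to element size times n):

#include <cassert>
#include <cstddef>

int main() {
    const size_t ts            = 2;   // element size, e.g. f16
    const size_t n_embd_head_v = 128;
    const size_t n_head_kv     = 8;
    const size_t n_embd_v_gqa  = n_embd_head_v*n_head_kv; // 1024
    const size_t kv_size       = 4096;

    // non-transposed V (dims: embd, head, token, stream)
    size_t nb1 = ts*n_embd_head_v;
    size_t nb2 = ts*n_embd_v_gqa;
    size_t nb3 = ts*n_embd_v_gqa*kv_size;
    assert(nb1 <= nb2); // the "v->nb[1] <= v->nb[2]" branch

    // transposed V (dims: token, head, embd, stream): tokens contiguous
    nb1 = ts*kv_size*n_embd_head_v;
    nb2 = ts*kv_size;
    nb3 = ts*kv_size*n_embd_v_gqa;
    assert(nb1 > nb2);  // the "v->nb[1] > v->nb[2]" branch

    // in both layouts nb3 spans one full stream, so an offset of
    // nb3*s0 skips the first s0 streams of the buffer
    (void)nb3;
    return 0;
}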

@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
    }

    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
+    n_kv = kv->get_n_kv(sinfos[i_cur]);

    return true;
}
src/llama-kv-cache.h — 6 changes: 3 additions & 3 deletions

@@ -38,8 +38,8 @@ class llama_kv_cache : public llama_memory_i {
        using idx_vec_t = std::vector<uint32_t>;

        // number of streams: ns = s1 - s0 + 1
-        llama_seq_id s0;
-        llama_seq_id s1;
+        uint32_t s0;
+        uint32_t s1;

        std::vector<llama_seq_id> strm; // [ns]
        std::vector<idx_vec_t>    idxs; // [ns]
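For reference, the sinfo.n_stream() call used by get_n_kv above follows from the inclusive [s0, s1] range documented in the comment. A trimmed-down sketch of the struct, keeping only the fields this hunk shows (the real slot_info has more members):

#include <cstdint>
#include <vector>

using llama_seq_id = int32_t; // as typedef'd in llama.h
using idx_vec_t    = std::vector<uint32_t>;

// trimmed-down slot_info: [s0, s1] is an inclusive range of streams
struct slot_info_sketch {
    uint32_t s0; // first stream used by the slot
    uint32_t s1; // last stream used by the slot

    std::vector<llama_seq_id> strm; // [ns]
    std::vector<idx_vec_t>    idxs; // [ns]

    // matches the comment: ns = s1 - s0 + 1
    uint32_t n_stream() const { return s1 - s0 + 1; }
};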
@@ -139,7 +139,7 @@ class llama_kv_cache : public llama_memory_i {
    // graph_build API
    //

-    uint32_t get_n_kv() const;
+    uint32_t get_n_kv(const slot_info & sinfo) const;

    // TODO: temporary
    bool get_supports_set_rows() const;