25 changes: 25 additions & 0 deletions src/llama-hparams.cpp
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {

GGML_ABORT("fatal error");
}

bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
}

return false;
}

// by default, all layers have kv
return true;
}

uint32_t llama_hparams::n_layer_kv() const {
uint32_t res = 0;

for (uint32_t il = 0; il < n_layer; ++il) {
if (has_kv(il)) {
res++;
}
}

return res;
}
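
For reference, the new helpers behave as follows: when n_layer_kv_from_start is non-negative, only the first n_layer_kv_from_start layers keep a KV cache; otherwise every layer does, and n_layer_kv() counts how many layers pass has_kv(). A minimal standalone sketch of that behavior, using a hypothetical toy_hparams struct rather than the real llama_hparams:

// Standalone illustration of the n_layer_kv_from_start semantics added above.
// toy_hparams is a hypothetical stand-in, not the real llama_hparams.
#include <cstdint>
#include <cstdio>

struct toy_hparams {
    uint32_t n_layer               = 30;
    int32_t  n_layer_kv_from_start = 20; // first 20 layers get a KV cache; -1 would mean "all layers"

    bool has_kv(uint32_t il) const {
        if (n_layer_kv_from_start >= 0) {
            return il < (uint32_t) n_layer_kv_from_start;
        }
        return true; // default: every layer has KV
    }

    uint32_t n_layer_kv() const {
        uint32_t res = 0;
        for (uint32_t il = 0; il < n_layer; ++il) {
            if (has_kv(il)) {
                res++;
            }
        }
        return res;
    }
};

int main() {
    toy_hparams hp;
    std::printf("layers with KV: %u of %u\n", hp.n_layer_kv(), hp.n_layer);            // 20 of 30
    std::printf("has_kv(19) = %d, has_kv(20) = %d\n", hp.has_kv(19), hp.has_kv(20));   // 1, 0
    return 0;
}
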
6 changes: 6 additions & 0 deletions src/llama-hparams.h
@@ -41,6 +41,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
uint32_t n_pos_per_embd() const;

bool is_swa(uint32_t il) const;

bool has_kv(uint32_t il) const;

// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
31 changes: 24 additions & 7 deletions src/llama-kv-cache-iswa.cpp
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {

// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}

return !model.hparams.is_swa(il);
};

const layer_filter_cb filter_swa = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}

return model.hparams.is_swa(il);
};

const uint32_t size_base = kv_size;

@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

kv_base = std::make_unique<llama_kv_cache>(
model, std::move(filter_base), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);

LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

kv_swa = std::make_unique<llama_kv_cache>(
model, std::move(filter_swa), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}

void llama_kv_cache_iswa::clear(bool data) {
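
The chained filters above mean a layer is placed in the non-SWA cache only if it passes both the externally supplied filter and !is_swa(il), and symmetrically for the SWA cache. A minimal standalone sketch of that composition, with a hypothetical external filter and a mocked is_swa() in place of the real model hparams:

// Hypothetical illustration of chaining an optional external layer filter with
// the SWA split used above; the predicates are stand-ins, not llama.cpp APIs.
#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    // external filter: pretend only layers 0..19 should be considered at all
    layer_filter_cb filter = [](int32_t il) { return il < 20; };

    // stand-in for hparams.is_swa(il): pretend every odd layer is SWA
    auto is_swa = [](int32_t il) { return (il % 2) == 1; };

    layer_filter_cb filter_base = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;          // rejected by the external filter
        }
        return !is_swa(il);        // non-SWA layers go to the base cache
    };

    layer_filter_cb filter_swa = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }
        return is_swa(il);         // SWA layers go to the SWA cache
    };

    for (int32_t il : {2, 3, 21}) {
        std::printf("layer %2d -> base: %d, swa: %d\n", il, (int) filter_base(il), (int) filter_swa(il));
    }
    // layer  2 -> base: 1, swa: 0
    // layer  3 -> base: 0, swa: 1
    // layer 21 -> base: 0, swa: 0   (filtered out entirely)
    return 0;
}
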
6 changes: 4 additions & 2 deletions src/llama-kv-cache-iswa.h
@@ -20,11 +20,13 @@ class llama_kv_cache_iswa : public llama_memory_i {
bool v_trans,
bool offload,
bool swa_full,
bool ,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);

~llama_kv_cache_iswa() = default;

68 changes: 35 additions & 33 deletions src/llama-kv-cache.cpp
@@ -17,40 +17,33 @@
//

llama_kv_cache::llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) :
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

GGML_ASSERT(kv_size % n_pad == 0);

// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
auto n_layer_cache = hparams.n_layer;
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
if (model.arch == LLM_ARCH_GLM4_MOE) {
// GLM-4.5: Only process up to last layer, skip final NextN layer
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
}
const uint32_t n_layer_kv = hparams.n_layer_kv();

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
__func__, hparams.n_embd_v_gqa_max());
}

for (uint32_t il = 0; il < n_layer_cache; il++) {
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
continue;
}

if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
continue;
}

@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
layers.push_back({ il, k, v, k_stream, v_stream, });
}

// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
if (model.arch == LLM_ARCH_GEMMA3N) {
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
if (reuse) {
LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);

for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
for (uint32_t il = 0; il < hparams.n_layer; il++) {
const int32_t il_reuse = reuse(il);

if (il_reuse < 0) {
LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
continue;
}

const bool is_swa = hparams.is_swa(il);
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
continue;
}

GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());

map_layer_ids[il] = map_layer_ids[il_reuse];

LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
}
}

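
The removed block hardcoded the Gemma3n policy: 20 cached layers, with every later layer reusing layer 18 (SWA) or 19 (non-SWA). With this change, that policy can instead be supplied by the caller as a layer_reuse_cb. The sketch below only illustrates what such a callback might look like, with is_swa() mocked; it is not the actual model wiring:

// Sketch of a reuse callback equivalent to the removed Gemma3n-specific logic:
// layers >= 20 have no KV cache of their own and reuse layer 18 (SWA) or 19 (non-SWA).
// is_swa() here is a stand-in; in llama.cpp the decision comes from llama_hparams.
#include <cstdint>
#include <cstdio>
#include <functional>

using layer_reuse_cb = std::function<int32_t(int32_t il)>; // < 0 means "no reuse"

int main() {
    const int32_t n_layer_cache = 20;

    auto is_swa = [](int32_t il) { return (il % 2) == 0; }; // illustrative pattern only

    layer_reuse_cb reuse = [&](int32_t il) -> int32_t {
        if (il < n_layer_cache) {
            return -1;                                   // this layer has its own KV cache
        }
        return n_layer_cache - (is_swa(il) ? 2 : 1);     // 18 for SWA layers, 19 otherwise
    };

    for (int32_t il : {5, 20, 21}) {
        std::printf("layer %2d -> reuse %d\n", il, reuse(il));
    }
    // layer  5 -> reuse -1
    // layer 20 -> reuse 18
    // layer 21 -> reuse 19
    return 0;
}
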
28 changes: 13 additions & 15 deletions src/llama-kv-cache.h
@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);

// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;

struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@
using slot_info_vec_t = std::vector<slot_info>;

llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);

~llama_kv_cache() = default;

57 changes: 29 additions & 28 deletions src/llama-memory-hybrid.cpp
@@ -9,32 +9,29 @@
//

llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k,
type_v,
v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max,
n_pad,
n_swa,
swa_type
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r,
type_s,
offload,
rs_size,
n_seq_max
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}

llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
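
When no explicit filters are passed, the defaults above send each layer to exactly one sub-memory: non-recurrent layers to the attention KV cache and recurrent layers to the recurrent state. A small standalone sketch of that split, with is_recurrent() mocked rather than taken from llama_hparams:

// Illustration of the default hybrid-memory split used above: each layer goes to
// exactly one sub-memory depending on is_recurrent(). The predicate is a mock.
#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    auto is_recurrent = [](int32_t il) { return il % 4 != 3; }; // hypothetical pattern

    const layer_filter_cb filter_attn = [&](int32_t il) { return !is_recurrent(il); };
    const layer_filter_cb filter_recr = [&](int32_t il) { return  is_recurrent(il); };

    for (int32_t il = 0; il < 8; ++il) {
        // every layer satisfies exactly one of the two filters
        std::printf("layer %d -> attn: %d, recurrent: %d\n", il, (int) filter_attn(il), (int) filter_recr(il));
    }
    return 0;
}
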
40 changes: 18 additions & 22 deletions src/llama-memory-hybrid.h
@@ -18,31 +18,27 @@

class llama_memory_hybrid : public llama_memory_i {
public:

// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;

llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);

~llama_memory_hybrid() = default;
