
Commit 26ade0a

feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent d6beaef commit 26ade0a
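
The refactor hinges on two architecture-family predicates that the new default branch calls instead of listing recurrent architectures as switch cases. Their definitions are not part of this diff; the sketch below is only a guess at their shape: the recurrent list is taken from the switch cases this commit removes, the hybrid-recurrent list is left empty as a placeholder, and the exact signatures are assumptions.

    // Sketch only (not from this commit): plausible shape of the predicates
    // used by the new default branch. The real definitions live elsewhere
    // in the tree.
    bool llm_arch_is_recurrent(const llm_arch & arch) {
        switch (arch) {
            // the architectures previously handled by the removed switch cases
            case LLM_ARCH_MAMBA:
            case LLM_ARCH_RWKV6:
            case LLM_ARCH_RWKV6QWEN2:
            case LLM_ARCH_RWKV7:
            case LLM_ARCH_ARWKV7:
                return true;
            default:
                return false;
        }
    }

    bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
        switch (arch) {
            // hybrid attention + recurrent architectures would be listed here;
            // this commit does not show which ones
            default:
                return false;
        }
    }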

File tree

1 file changed: +63 -47 lines

src/llama-model.cpp

Lines changed: 63 additions & 47 deletions
@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13202,6 +13203,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13210,58 +13213,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model              */ *this,
+                        /* attn_type_k        */ params.type_k,
+                        /* attn_type_v        */ params.type_v,
+                        /* attn_v_trans       */ !cparams.flash_attn,
+                        /* attn_kv_size       */ cparams.n_ctx,
+                        /* attn_n_pad         */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa         */ hparams.n_swa,
+                        /* attn_swa_type      */ hparams.swa_type,
+                        /* recurrent_type_k   */ GGML_TYPE_F32,
+                        /* recurrent_type_v   */ GGML_TYPE_F32,
+                        /* recurrent_kv_size  */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max          */ cparams.n_seq_max,
+                        /* offload            */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
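
The comment labels on the new hybrid branch make the constructor's shape readable from the call site: one parameter group for the attention-layer (unified) cache, one for the recurrent-layer cache, and a few shared settings. The declaration below is inferred from that call, not copied from llama-kv-cache-hybrid-recurrent.h, so names, types, and the base class are assumptions.

    // Inferred sketch of the constructor used above; the real header may
    // differ. The result is assigned to llama_memory_i * res, so the type
    // presumably implements that interface.
    struct llama_kv_cache_hybrid_recurrent : llama_memory_i {
        llama_kv_cache_hybrid_recurrent(
                const llama_model & model,              // *this
                // attention-layer cache (cf. llama_kv_cache_unified)
                ggml_type           attn_type_k,        // params.type_k
                ggml_type           attn_type_v,        // params.type_v
                bool                attn_v_trans,       // !cparams.flash_attn
                uint32_t            attn_kv_size,       // cparams.n_ctx
                uint32_t            attn_n_pad,         // llama_kv_cache_unified::get_padding(cparams)
                uint32_t            attn_n_swa,         // hparams.n_swa
                llama_swa_type      attn_swa_type,      // hparams.swa_type
                // recurrent-layer cache (cf. llama_kv_cache_recurrent)
                ggml_type           recurrent_type_k,   // GGML_TYPE_F32
                ggml_type           recurrent_type_v,   // GGML_TYPE_F32
                uint32_t            recurrent_kv_size,  // std::max((uint32_t) 1, cparams.n_seq_max)
                // shared
                uint32_t            n_seq_max,          // cparams.n_seq_max
                bool                offload);           // cparams.offload_kqv
    };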
