 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13202,6 +13203,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13210,58 +13213,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
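For reference, the dispatch above relies on arch predicate helpers instead of explicit case labels. A minimal sketch of what they could look like, assuming they sit next to the arch definitions; the recurrent list mirrors the case labels removed in this diff, while the concrete set of hybrid-recurrent architectures is not shown here and is left as a hypothetical placeholder:

    bool llm_arch_is_recurrent(const llm_arch & arch) {
        // same architectures as the explicit case labels removed above
        switch (arch) {
            case LLM_ARCH_MAMBA:
            case LLM_ARCH_RWKV6:
            case LLM_ARCH_RWKV6QWEN2:
            case LLM_ARCH_RWKV7:
            case LLM_ARCH_ARWKV7:
                return true;
            default:
                return false;
        }
    }

    bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
        // placeholder: architectures mixing attention and recurrent layers
        // would be listed here (the list is not part of this diff)
        switch (arch) {
            // case LLM_ARCH_...: return true;
            default:
                return false;
        }
    }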