 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13197,6 +13198,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13205,58 +13208,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
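
Note: the new default branch no longer lists recurrent architectures in the switch; it dispatches on llm_arch_is_recurrent() and llm_arch_is_hybrid_recurrent() instead. As a rough sketch (not part of this commit; the real definition lives in llama-arch.cpp and its signature and architecture list may differ), the recurrent check plausibly covers exactly the cases removed above, with the hybrid check doing the same for architectures that mix attention layers with recurrent state:

// Sketch only: assumed definition of the helper used in the diff above,
// mirroring the architectures whose explicit cases were dropped from
// create_memory(). The authoritative version is in llama-arch.cpp.
#include "llama-arch.h"

bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:        // state-space model, keeps recurrent state
        case LLM_ARCH_RWKV6:        // RWKV-family models likewise
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;           // all other arches take the unified/iSWA path
    }
}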