@@ -13196,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13204,58 +13206,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
 
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max),
+                                cparams.n_seq_max)
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );
 
-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max
+                    );
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                params.swa_full,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
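
A minimal standalone sketch of the layer-partitioning step used by the hybrid branch above (hypothetical, not part of the commit): layer indices for which the recurrent predicate holds would feed the recurrent child cache, the rest the unified child cache. The is_recurrent() helper and the layer count are stand-ins for hparams.recurrent_layer(il) and hparams.n_layer.

    // hypothetical sketch: split layers into recurrent vs. unified groups
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // stand-in for hparams.recurrent_layer(il): pretend every third layer is attention
    static bool is_recurrent(uint32_t il) {
        return il % 3 != 0;
    }

    int main() {
        const uint32_t n_layer = 9;  // stand-in for hparams.n_layer

        std::vector<size_t> recurrent_layers;
        std::vector<size_t> unified_layers;
        for (uint32_t il = 0; il < n_layer; ++il) {
            if (is_recurrent(il)) {
                recurrent_layers.push_back(il);
            } else {
                unified_layers.push_back(il);
            }
        }

        // each child cache of the hybrid KV cache would own one of these index lists
        std::printf("recurrent layers: %zu, unified layers: %zu\n",
                    recurrent_layers.size(), unified_layers.size());
        return 0;
    }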