@@ -13192,6 +13192,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;

     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13200,58 +13202,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }

-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max),
+                                cparams.n_seq_max)
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );

-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max
+                    );
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                params.swa_full,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
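
Note: `llm_arch_is_hybrid()`, `llm_arch_is_recurrent()`, and `hparams.recurrent_layer(il)` are not defined in this hunk; their real definitions live elsewhere in the branch. As a rough sketch of the assumed contract (not the actual implementation), the recurrent predicate would cover exactly the architectures whose case labels are removed above, and the hybrid predicate would cover architectures that mix recurrent and attention layers:

```cpp
// Sketch only: signatures and architecture lists are assumptions, not the
// definitions from this commit.
bool llm_arch_is_recurrent(const llm_arch & arch) {
    // the pure-recurrent architectures formerly handled by their own case labels
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid(const llm_arch & arch) {
    // architectures mixing recurrent and attention layers; which ones qualify
    // depends on what the rest of the branch wires up to llama_kv_cache_hybrid
    switch (arch) {
        default:
            return false;
    }
}
```

With that split, `hparams.recurrent_layer(il)` is what decides, per layer, whether layer `il` is served by the recurrent child cache or the unified (attention) child cache.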