
Commit db39e8b

feat: Instantiate hybrid cache for hybrid models (currently none)
This includes a slight architectural change: create_memory now lists model architectures explicitly in the switch statement only if their required cache type is not already handled by the llm_arch_is_[recurrent|hybrid] checks.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: cb09814 · commit: db39e8b
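The commit message refers to two architecture predicates, llm_arch_is_recurrent and llm_arch_is_hybrid, which are not part of this diff. As a rough, non-authoritative sketch of the dispatch they provide (signatures and bodies are assumed here rather than copied from the repository: the recurrent list mirrors the case labels removed below, and the hybrid check matches nothing because the commit title says there are currently no hybrid models):

bool llm_arch_is_recurrent(const llm_arch & arch) {
    // architectures whose cache is purely recurrent state
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid(const llm_arch & arch) {
    // hybrid = some layers recurrent, some layers attention;
    // no such architecture is wired up yet, so this matches nothing
    switch (arch) {
        default:
            return false;
    }
}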

File tree: 1 file changed (+97 −45)

src/llama-model.cpp

Lines changed: 97 additions & 45 deletions

@@ -13196,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13204,58 +13206,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
 
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max),
+                                cparams.n_seq_max)
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );
 
-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max
+                    );
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                params.swa_full,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
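For reference, the emplace_back calls above build each llama_kv_cache_hybrid::child_cache from an owning cache pointer plus the list of layer indices that child serves. A minimal sketch of the shape such a descriptor would need for those calls to compile (hypothetical; the real definition lives in the llama-kv-cache headers and may carry more state):

#include <memory>
#include <vector>

// hypothetical descriptor pairing a child cache with the layers it owns;
// llama_kv_cache is the project's cache interface
struct child_cache {
    std::unique_ptr<llama_kv_cache> child;      // recurrent or unified child cache
    std::vector<size_t>             layer_ids;  // layer indices routed to this child

    child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_) :
        child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
};

With a descriptor like this, the hybrid cache can route per-layer K/V operations to whichever child owns the layer, which is presumably why the same hparams.recurrent_layer(il) predicate drives both the index partition and the per-child layer filters.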
