Commit 0cf4d12

feat: Instantiate hybrid cache for hybrid models (currently none)
This includes a slight architectural change: create_memory now only lists model architectures in its switch statement if their required cache type is not already handled by the llm_arch_is_[recurrent|hybrid] checks.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: 0701efe

1 file changed: src/llama-model.cpp (97 additions, 45 deletions)
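The commit message keys the new default branch off arch-level predicates rather than explicit case labels. As a point of reference, here is a minimal sketch of the shape those llm_arch_is_recurrent / llm_arch_is_hybrid helpers take; the helpers themselves are not part of this diff, so the enum values shown and the empty hybrid list (per the commit title, "currently none") are assumptions for illustration only.

```cpp
// Hypothetical sketch of the arch-level predicates the new default branch relies on.
// The real helpers live elsewhere in the tree; only the recurrent architectures
// named in the removed case labels are listed here.
enum llm_arch {
    LLM_ARCH_MAMBA,
    LLM_ARCH_RWKV6,
    LLM_ARCH_RWKV6QWEN2,
    LLM_ARCH_RWKV7,
    LLM_ARCH_ARWKV7,
    // ... remaining architectures ...
};

// true for architectures whose cache is purely recurrent (SSM/RWKV-style state)
static bool llm_arch_is_recurrent(llm_arch arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

// true for architectures that mix attention layers with recurrent layers;
// per the commit title there are currently no such architectures registered
static bool llm_arch_is_hybrid(llm_arch /*arch*/) {
    return false;
}
```

With that shape in mind, the diff below can drop the per-architecture recurrent case labels entirely and let the default branch pick the cache type.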
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13192,6 +13192,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13200,58 +13202,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                    *this,
-                    nullptr,
-                    GGML_TYPE_F32,
-                    GGML_TYPE_F32,
-                    cparams.offload_kqv,
-                    std::max((uint32_t) 1, cparams.n_seq_max),
-                    cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
 
-                    res = new llama_kv_cache_unified_iswa(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        params.swa_full,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        cparams.n_batch,
-                        padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max),
+                                cparams.n_seq_max)
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );
 
-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                         *this,
                         nullptr,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        padding,
-                        hparams.n_swa,
-                        hparams.swa_type);
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max
+                    );
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            params.swa_full,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                    }
                 }
             }
     }

0 commit comments
