Skip to content

Commit 423d890

Browse files
committed
fix(nemotron-h): Fix KV cache over-allocation for hybrid architecture
- Add custom cache initialization filters for LLM_ARCH_NEMOTRON_H - Attention cache only allocated for layers 14, 21, 30, 39 (attention layers) - Recurrent cache only allocated for SSM layers using is_recurrent() - Reduces KV cache memory usage from 264MB (29 layers) to 64MB (4 layers) - Implements proper Mamba2-style SSM with x/z gating and SiLU activation - Resolves infinite hang issue during token generation
1 parent cce8cb1 commit 423d890

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

src/llama-model.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18364,8 +18364,22 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1836418364
/* n_seq_max */ cparams.n_seq_max,
1836518365
/* offload */ cparams.offload_kqv,
1836618366
/* unified */ cparams.kv_unified,
18367-
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
18368-
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
18367+
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_NEMOTRON_H) ?
18368+
[&](int32_t il) {
18369+
// For NEMOTRON_H: only allocate cache for attention layers (14, 21, 30, 39)
18370+
if (arch == LLM_ARCH_NEMOTRON_H) {
18371+
return (il == 14 || il == 21 || il == 30 || il == 39);
18372+
}
18373+
return true; // FALCON_H1 case
18374+
} : (llama_memory_hybrid::layer_filter_cb)nullptr,
18375+
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_NEMOTRON_H) ?
18376+
[&](int32_t il) {
18377+
// For NEMOTRON_H: allocate recurrent state for SSM layers (non-attention, non-MLP)
18378+
if (arch == LLM_ARCH_NEMOTRON_H) {
18379+
return hparams.is_recurrent(il);
18380+
}
18381+
return true; // FALCON_H1 case
18382+
} : (llama_memory_hybrid::layer_filter_cb)nullptr);
1836918383
} else {
1837018384
const auto padding = llama_kv_cache::get_padding(cparams);
1837118385

0 commit comments

Comments
 (0)