Skip to content

Commit 11a3811

Browse files
authored
memory : handle kv_unified for hybrid models (#15050)
1 parent 97366dc commit 11a3811

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

src/llama-memory-hybrid.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                     /* common */
                     uint32_t n_seq_max,
                     bool     offload,
+                    bool     unified,
                     /* layer filters */
                     layer_filter_cb && filter_attn,
                     layer_filter_cb && filter_recr) :
@@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                 type_v,
                 v_trans,
                 offload,
-                1,
+                unified,
                 kv_size,
                 n_seq_max,
                 n_pad,

src/llama-memory-hybrid.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class llama_memory_hybrid : public llama_memory_i {
             /* common */
             uint32_t n_seq_max,
             bool     offload,
+            bool     unified,
             /* layer filters */
             layer_filter_cb && filter_attn = nullptr,
             layer_filter_cb && filter_recr = nullptr);

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17598,6 +17598,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,
+                        /* unified           */ cparams.kv_unified,
                         /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                         /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
             } else {

0 commit comments

Comments (0)