@@ -49,50 +49,6 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
4949 n_seq_max
5050 )) {}
5151
52- void llama_kv_cache_hybrid_recurrent::clear () {
53- kv_attn ->clear ();
54- kv_recurrent->clear ();
55- }
56-
57- bool llama_kv_cache_hybrid_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
58- // Try removing from the recurrent cache first since it may fail. If it does
59- // fail, the cache will not have been mutated.
60- if (!kv_recurrent->seq_rm (seq_id, p0, p1)) {
61- return false ;
62- }
63- return kv_attn->seq_rm (seq_id, p0, p1);
64- }
65-
66- void llama_kv_cache_hybrid_recurrent::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
67- kv_attn ->seq_cp (seq_id_src, seq_id_dst, p0, p1);
68- kv_recurrent->seq_cp (seq_id_src, seq_id_dst, p0, p1);
69- }
70-
71- void llama_kv_cache_hybrid_recurrent::seq_keep (llama_seq_id seq_id) {
72- kv_attn ->seq_keep (seq_id);
73- kv_recurrent->seq_keep (seq_id);
74- }
75-
76- void llama_kv_cache_hybrid_recurrent::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
77- kv_attn->seq_add (seq_id, p0, p1, shift);
78- kv_recurrent->seq_add (seq_id, p0, p1, shift);
79- }
80-
81- void llama_kv_cache_hybrid_recurrent::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
82- kv_attn ->seq_div (seq_id, p0, p1, d);
83- kv_recurrent->seq_div (seq_id, p0, p1, d);
84- }
85-
86- llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min (llama_seq_id seq_id) const {
87- // the min of the total cache is the max of the two caches' min values
88- return std::max (kv_attn->seq_pos_min (seq_id), kv_recurrent->seq_pos_min (seq_id));
89- }
90-
91- llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max (llama_seq_id seq_id) const {
92- // the max of the total cache is the min of the two caches' max values
93- return std::min (kv_attn->seq_pos_max (seq_id), kv_recurrent->seq_pos_max (seq_id));
94- }
95-
9652llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch (const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
9753
9854 // since this includes a recurrent cache, we cannot use split_simple
@@ -135,23 +91,59 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
13591 return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this );
13692}
13793
138- bool llama_kv_cache_hybrid_recurrent::update (llama_context & lctx) {
139- bool res = false ;
94+ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update (llama_context * lctx, bool optimize) {
95+ return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
96+ this ,
97+ static_cast <llama_kv_cache_unified_state *>( kv_attn ->init_update (lctx, optimize).release ()),
98+ static_cast <llama_kv_cache_recurrent_state *>(kv_recurrent->init_update (lctx, optimize).release ()));
99+ }
100+
101+ bool llama_kv_cache_hybrid_recurrent::get_can_shift () const {
102+ // Shifting is trivially supported for recurrent
103+ return kv_attn->get_can_shift ();
104+ }
105+ void llama_kv_cache_hybrid_recurrent::clear () {
106+ kv_attn ->clear ();
107+ kv_recurrent->clear ();
108+ }
140109
141- res = res | kv_attn ->update (lctx);
142- res = res | kv_recurrent->update (lctx);
110+ bool llama_kv_cache_hybrid_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
111+ // Try removing from the recurrent cache first since it may fail. If it does
112+ // fail, the cache will not have been mutated.
113+ if (!kv_recurrent->seq_rm (seq_id, p0, p1)) {
114+ return false ;
115+ }
116+ return kv_attn->seq_rm (seq_id, p0, p1);
117+ }
143118
144- return res;
119+ void llama_kv_cache_hybrid_recurrent::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
120+ kv_attn ->seq_cp (seq_id_src, seq_id_dst, p0, p1);
121+ kv_recurrent->seq_cp (seq_id_src, seq_id_dst, p0, p1);
145122}
146123
147- void llama_kv_cache_hybrid_recurrent::defrag_sched ( float thold ) {
148- kv_attn ->defrag_sched (thold );
149- kv_recurrent->defrag_sched (thold );
124+ void llama_kv_cache_hybrid_recurrent::seq_keep (llama_seq_id seq_id ) {
125+ kv_attn ->seq_keep (seq_id );
126+ kv_recurrent->seq_keep (seq_id );
150127}
151128
152- bool llama_kv_cache_hybrid_recurrent::get_can_shift () const {
153- // Shifting is trivially supported for recurrent
154- return kv_attn->get_can_shift ();
129+ void llama_kv_cache_hybrid_recurrent::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
130+ kv_attn->seq_add (seq_id, p0, p1, shift);
131+ kv_recurrent->seq_add (seq_id, p0, p1, shift);
132+ }
133+
134+ void llama_kv_cache_hybrid_recurrent::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
135+ kv_attn ->seq_div (seq_id, p0, p1, d);
136+ kv_recurrent->seq_div (seq_id, p0, p1, d);
137+ }
138+
139+ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min (llama_seq_id seq_id) const {
140+ // the min of the total cache is the max of the two caches' min values
141+ return std::max (kv_attn->seq_pos_min (seq_id), kv_recurrent->seq_pos_min (seq_id));
142+ }
143+
144+ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max (llama_seq_id seq_id) const {
145+ // the max of the total cache is the min of the two caches' max values
146+ return std::min (kv_attn->seq_pos_max (seq_id), kv_recurrent->seq_pos_max (seq_id));
155147}
156148
157149void llama_kv_cache_hybrid_recurrent::state_write (llama_io_write_i & io, llama_seq_id seq_id) const {
@@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
173165}
174166
175167llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (llama_memory_status status)
176- : status(status), state_attn(status), state_recurrent(status) {}
168+ : status(status),
169+ state_attn(new llama_kv_cache_unified_state(status)),
170+ state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
177171
178172llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (llama_kv_cache_hybrid_recurrent * kv)
179173 : status(LLAMA_MEMORY_STATUS_SUCCESS),
180174 kv(kv),
181- state_attn(status, kv->get_kv_attn ()),
182- state_recurrent(status, kv->get_kv_recurrent ()) {}
175+ state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn ())),
176+ state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent ())) {}
177+
178+ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (
179+ llama_kv_cache_hybrid_recurrent * kv,
180+ llama_kv_cache_unified_state * state_unified,
181+ llama_kv_cache_recurrent_state * state_recurrent)
182+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
183+ kv(kv),
184+ state_attn(state_unified),
185+ state_recurrent(state_recurrent) {}
183186
184187llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (
185188 llama_kv_cache_hybrid_recurrent * kv,
@@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
194197 // NOTE: these child states are only used as wrapper APIs for the
195198 // const methods, so we use the "init full" signature since the
196199 // actual state is not used.
197- state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn ()),
198- state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent ()) {}
200+ state_attn(new llama_kv_cache_unified_state( kv->get_kv_attn () )),
201+ state_recurrent(new llama_kv_cache_recurrent_state( LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent () )) {}
199202
200203
201204bool llama_kv_cache_hybrid_recurrent_state::next () {
@@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
232235 return ubatches[i_next];
233236}
234237
235- const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
236- return & state_attn;
238+ const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
239+ return state_attn. get () ;
237240}
238241
239242const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent () const {
240- return & state_recurrent;
243+ return state_recurrent. get () ;
241244}
0 commit comments