@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
100100
101101llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update (llama_context * lctx, bool optimize) {
102102 return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
103- this ,
104103 static_cast <llama_kv_cache_unified_state *>( kv_attn ->init_update (lctx, optimize).release ()),
105104 static_cast <llama_kv_cache_recurrent_state *>(kv_recurrent->init_update (lctx, optimize).release ()));
106105}
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
179178
180179llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (llama_kv_cache_hybrid_recurrent * kv)
181180 : status(LLAMA_MEMORY_STATUS_SUCCESS),
182- kv(kv),
183181 state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn ())),
184182 state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent ())) {}
185183
186184llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (
187- llama_kv_cache_hybrid_recurrent * kv,
188185 llama_kv_cache_unified_state * state_unified,
189186 llama_kv_cache_recurrent_state * state_recurrent)
190187 : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
191- kv(kv),
192188 state_attn(state_unified),
193189 state_recurrent(state_recurrent) {}
194190
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
198194 std::vector<uint32_t > heads_attn,
199195 std::vector<llama_ubatch> ubatches)
200196 : status(LLAMA_MEMORY_STATUS_SUCCESS),
201- kv(kv),
202197 sbatch(std::move(sbatch)),
203- heads_attn(std::move(heads_attn)),
204198 ubatches(std::move(ubatches)),
205- // NOTE: these child states are only used as wrapper APIs for the
206- // const methods, so we use the "init full" signature since the
207- // actual state is not used.
208- state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn ())),
209- state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent ())) {}
199+ // note: here we copy the ubatches. not sure if this is ideal
200+ state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn (), {}, std::move(heads_attn), this->ubatches)),
201+ state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent (), {}, this ->ubatches)) {}
210202
211203
212204bool llama_kv_cache_hybrid_recurrent_state::next () {
213205 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
214206
207+ state_attn ->next ();
208+ state_recurrent->next ();
209+
215210 if (++i_next >= ubatches.size ()) {
216211 return false ;
217212 }
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {
222217bool llama_kv_cache_hybrid_recurrent_state::apply () {
223218 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
224219
225- kv->get_kv_attn () ->apply_ubatch (heads_attn[i_next], ubatches[i_next]);
226- kv->get_kv_recurrent ()->find_slot (ubatches[i_next]);
220+ bool res = true ;
227221
228- return true ;
222+ res = res & state_attn ->apply ();
223+ res = res & state_recurrent->apply ();
224+
225+ return res;
229226}
230227
231228std::vector<int64_t > & llama_kv_cache_hybrid_recurrent_state::out_ids () {
0 commit comments