@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
123123
124124    assert (heads_base.size () == heads_swa.size ());
125125
126-     return  std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, 
126+     return  std::make_unique<llama_kv_cache_unified_iswa_state>(
127127            this , std::move (sbatch), std::move (heads_base), std::move (heads_swa), std::move (ubatches));
128128}
129129
130130llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full () {
131-     return  std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,  this );
131+     return  std::make_unique<llama_kv_cache_unified_iswa_state>(this );
132132}
133133
134- bool  llama_kv_cache_unified_iswa::update (llama_context & lctx) {
135-     bool  res = false ;
136- 
137-     res = res | kv_base->update (lctx);
138-     res = res | kv_swa ->update (lctx);
139- 
140-     return  res;
141- }
142- 
143- void  llama_kv_cache_unified_iswa::defrag_sched (float  thold) {
144-     kv_base->defrag_sched (thold);
145-     kv_swa ->defrag_sched (thold);
134+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update (llama_context * lctx, bool  optimize) {
135+     return  std::make_unique<llama_kv_cache_unified_iswa_state>(this , lctx, optimize);
146136}
147137
148138bool  llama_kv_cache_unified_iswa::get_can_shift () const  {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
174164llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (llama_memory_status status) : status(status) {}
175165
176166llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
177-         llama_memory_status status,
178-         llama_kv_cache_unified_iswa * kv) : status(status) {
179-     state_base.reset (new  llama_kv_cache_unified_state (status, kv->get_base ()));
180-     state_swa .reset (new  llama_kv_cache_unified_state (status, kv->get_swa  ()));
167+         llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168+     state_base = kv->get_base ()->init_full ();
169+     state_swa  = kv->get_swa  ()->init_full ();
170+ 
171+     status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
172+ }
173+ 
174+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
175+         llama_kv_cache_unified_iswa * kv,
176+         llama_context * lctx,
177+         bool  optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
178+     state_base = kv->get_base ()->init_update (lctx, optimize);
179+     state_swa  = kv->get_swa  ()->init_update (lctx, optimize);
180+ 
181+     status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
181182}
182183
183184llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
184-         llama_memory_status status,
185185        llama_kv_cache_unified_iswa * kv,
186186        llama_sbatch sbatch,
187187        std::vector<uint32_t > heads_base,
188188        std::vector<uint32_t > heads_swa,
189189        std::vector<llama_ubatch> ubatches)
190-     : status(status),
191-     sbatch(std::move(sbatch)),
192-     ubatches(std::move(ubatches)) {
193-         //  note: here we copy the ubatches. not sure if this is ideal
194-         state_base.reset (new  llama_kv_cache_unified_state (status, kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195-         state_swa .reset (new  llama_kv_cache_unified_state (status, kv->get_swa  (), {}, std::move (heads_swa),  this ->ubatches ));
196-     }
190+         : status(LLAMA_MEMORY_STATUS_SUCCESS),
191+         sbatch(std::move(sbatch)),
192+         ubatches(std::move(ubatches)) {
193+     //  note: here we copy the ubatches. not sure if this is ideal
194+     state_base.reset (new  llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195+     state_swa .reset (new  llama_kv_cache_unified_state (kv->get_swa  (), {}, std::move (heads_swa),  this ->ubatches ));
196+ 
197+     status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
198+ }
197199
198200llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state () = default ;
199201
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
233235
234236const  llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch () const  {
235237    assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
238+ 
236239    return  ubatches[i_next];
237240}
238241
239242const  llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base () const  {
240243    assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
241244
242-     return  state_base.get ();
245+     return  static_cast < const  llama_kv_cache_unified_state *>( state_base.get () );
243246}
244247
245248const  llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa ()  const  {
246249    assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
247250
248-     return  state_swa.get ();
251+     return  static_cast < const  llama_kv_cache_unified_state *>( state_swa.get () );
249252}
0 commit comments