@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
123123
124124 assert (heads_base.size () == heads_swa.size ());
125125
126- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
126+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
127127 this , std::move (sbatch), std::move (heads_base), std::move (heads_swa), std::move (ubatches));
128128}
129129
130130llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full () {
131- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this );
131+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this );
132132}
133133
134- bool llama_kv_cache_unified_iswa::update (llama_context & lctx) {
135- bool res = false ;
136-
137- res = res | kv_base->update (lctx);
138- res = res | kv_swa ->update (lctx);
139-
140- return res;
141- }
142-
143- void llama_kv_cache_unified_iswa::defrag_sched (float thold) {
144- kv_base->defrag_sched (thold);
145- kv_swa ->defrag_sched (thold);
134+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update (llama_context * lctx, bool optimize) {
135+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this , lctx, optimize);
146136}
147137
148138bool llama_kv_cache_unified_iswa::get_can_shift () const {
@@ -174,25 +164,48 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
174164llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (llama_memory_status status) : status(status) {}
175165
176166llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
177- llama_memory_status status,
178- llama_kv_cache_unified_iswa * kv) : status(status) {
179- state_base.reset (new llama_kv_cache_unified_state (status, kv->get_base ()));
180- state_swa .reset (new llama_kv_cache_unified_state (status, kv->get_swa ()));
167+ llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168+ state_base = kv->get_base ()->init_full ();
169+ state_swa = kv->get_swa ()->init_full ();
170+ }
171+
172+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
173+ llama_kv_cache_unified_iswa * kv,
174+ llama_context * lctx,
175+ bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
176+ state_base = kv->get_base ()->init_update (lctx, optimize);
177+ state_swa = kv->get_swa ()->init_update (lctx, optimize);
178+
179+ // TODO: this is very ugly - how to make it simpler?
180+ // the llama_memory_status enum is not very well designed
181+ if (state_base->get_status () != LLAMA_MEMORY_STATUS_SUCCESS && state_base->get_status () != LLAMA_MEMORY_STATUS_NO_UPDATE) {
182+ status = state_base->get_status ();
183+ return ;
184+ }
185+
186+ if (state_swa->get_status () != LLAMA_MEMORY_STATUS_SUCCESS && state_swa->get_status () != LLAMA_MEMORY_STATUS_NO_UPDATE) {
187+ status = state_swa->get_status ();
188+ return ;
189+ }
190+
191+ if (state_base->get_status () == LLAMA_MEMORY_STATUS_NO_UPDATE && state_swa->get_status () == LLAMA_MEMORY_STATUS_NO_UPDATE) {
192+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
193+ return ;
194+ }
181195}
182196
183197llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
184- llama_memory_status status,
185198 llama_kv_cache_unified_iswa * kv,
186199 llama_sbatch sbatch,
187200 std::vector<uint32_t > heads_base,
188201 std::vector<uint32_t > heads_swa,
189202 std::vector<llama_ubatch> ubatches)
190- : status(status ),
203+ : status(LLAMA_MEMORY_STATUS_SUCCESS ),
191204 sbatch(std::move(sbatch)),
192205 ubatches(std::move(ubatches)) {
193206 // note: here we copy the ubatches. not sure if this is ideal
194- state_base.reset (new llama_kv_cache_unified_state (status, kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195- state_swa .reset (new llama_kv_cache_unified_state (status, kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
207+ state_base.reset (new llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
208+ state_swa .reset (new llama_kv_cache_unified_state (kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
196209 }
197210
198211llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state () = default ;
@@ -239,11 +252,11 @@ const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
239252const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base () const {
240253 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
241254
242- return state_base.get ();
255+ return static_cast < const llama_kv_cache_unified_state *>( state_base.get () );
243256}
244257
245258const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa () const {
246259 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
247260
248- return state_swa.get ();
261+ return static_cast < const llama_kv_cache_unified_state *>( state_swa.get () );
249262}
0 commit comments