@@ -95,22 +95,19 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    GGML_UNUSED(embd_all);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -125,23 +122,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -156,22 +150,22 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     // but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_unified_iswa_state
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) :
-    ctx_base(kv->get_base()->init_update(lctx, optimize)),
-    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
+        llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches) :
-    ubatches(std::move(ubatches)),
+        std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_unified_iswa_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    ctx_base->next();
-    ctx_swa ->next();
+    state_base->next();
+    state_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -245,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
+bool llama_kv_cache_unified_iswa_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & ctx_base->apply();
-    res = res & ctx_swa ->apply();
+    res = res & state_base->apply();
+    res = res & state_swa ->apply();
 
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa .get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa .get());
 }
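
Usage note: a minimal caller-side sketch of how the state returned by init_batch() is meant to be driven, using only the llama_kv_cache_unified_iswa / llama_memory_state methods visible in the hunks above. The helper name and the surrounding driver logic are hypothetical (not part of this diff), and the graph evaluation step is elided.

// Hypothetical driver, for illustration only: feeds one llama_batch through the
// iSWA cache using the state interface defined in this file.
static bool decode_with_iswa_cache(llama_kv_cache_unified_iswa & kv, const llama_batch & batch, uint32_t n_ubatch) {
    // split the batch into ubatches and find slots in both the base and SWA caches
    llama_memory_state_ptr mstate = kv.init_batch(batch, n_ubatch, /*embd_pooled=*/false, /*logits_all=*/false);

    if (mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false; // neither the simple nor the equal split could be placed
    }

    do {
        if (!mstate->apply()) {
            return false; // committing the reserved cells to one of the two caches failed
        }

        const llama_ubatch & ubatch = mstate->get_ubatch();
        // ... build and evaluate the compute graph for `ubatch` here ...
        (void) ubatch;
    } while (mstate->next()); // advances both sub-states; false once all ubatches are consumed

    return true;
}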