@@ -130,42 +130,20 @@ bool llama_batch_allocr::init(
                 warn = true;
             }
         }
-
-        if (warn) {
-            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
-
-            output.resize(batch.n_tokens, true);
-            batch.logits = output.data();
-        }
-    }
-
-    //
-    // compute stats
-    //
-
-    this->n_embd = n_embd;
-
-    // count the outputs in this batch
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        n_outputs += batch.logits[i] != 0;
     }
-
-    // determine coupled sequences
-    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        const llama_seq_id s0 = batch.seq_id[i][0];
-
-        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id s1 = batch.seq_id[i][s];
-
-            seq_pos[s1].insert(batch.pos[i]);
-
-            if (s > 0) {
-                // mark that sequence s1 is coupled to s0
-                seq_cpl[s1][s0] = true;
-
-                // note: tracking the other way around is not necessary for now
-                // seq_cpl[s0][s1] = true;
+    if (batch->logits) {
+        if (ubatch.equal_seqs) {
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_output = batch->logits[id];
+                ubatch.output[ubatch.n_tokens + i] = is_output;
+                if (is_output) { out_ids.push_back(id); }
+            }
+        } else {
+            // simple split
+            ubatch.output = batch->logits + seq.offset;
+            for (size_t i = 0; i < length; ++i) {
+                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
             }
         }
     }
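
The `equal_seqs` branch restored above gathers tokens through the `ids[]` permutation, so each token's output flag has to be copied individually and the original batch index of every output token remembered in `out_ids`. The following standalone sketch (not llama.cpp code; the array contents are invented for illustration) shows just that bookkeeping:

```cpp
// Toy illustration of the equal_seqs output gathering: copy output flags
// through a permutation and record the original batch indices in out_ids.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical batch of 6 tokens; logits[i] != 0 marks token i as an output
    std::vector<int8_t> logits = {0, 0, 1, 0, 1, 1};
    // ids[] reorders the batch (e.g. after grouping by sequence)
    std::vector<size_t> ids = {2, 0, 1, 5, 3, 4};

    std::vector<int8_t> ubatch_output; // per-token output flags in ubatch order
    std::vector<size_t> out_ids;       // original batch indices of output tokens

    for (size_t i = 0; i < ids.size(); ++i) {
        const size_t id        = ids[i];
        const int8_t is_output = logits[id];
        ubatch_output.push_back(is_output);
        if (is_output) { out_ids.push_back(id); }
    }

    // out_ids lets the caller scatter results computed in ubatch order
    // back to the positions requested in the original batch
    for (size_t id : out_ids) { printf("output token: batch index %zu\n", id); }
    return 0;
}
```
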
@@ -281,141 +259,49 @@ bool llama_batch_allocr::init(
         }
     }
 
-    if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
-                if (seq_cpl[s0][s1]) {
-                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
-                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
-                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
-                        return false;
-                    }
-                }
+llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        size_t length = 0;
+        size_t n_tokens_in_ubatch = 0;
+        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+        // smallest first, because it's easier to split this way;
+        // starting from the end to pop in constant time.
+        for (size_t i = seq.size(); i-- > 0;) {
+            llama_sbatch_seq & s = seq[i];
+            GGML_ASSERT(s.length > 0);
+            if (length == 0) {
+                length = s.length < n_ubatch ? s.length : n_ubatch;
             }
+            add_seq_to_ubatch(ubatch, s, length);
+            n_tokens_in_ubatch += length;
+            // shared prompts can't be mixed with any of their sequences,
+            // so it's safer to compute them in their own ubatch
+            if (s.n_seq_id > 1) { break; }
+            // stop when there isn't enough space for another sequence
+            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
-
-    // disallow partial sequence sub-sets:
-    //
-    //       invalid:     x
-    //            i: 0 1 2 ...
-    // ---------------------------------------
-    // seq_id[i][0]: 0 0 1
-    // seq_id[i][1]: 1 1 2
-    // seq_id[i][2]: 2
-    //
-    // disallow decreasing sequence positions:
-    //
-    //       invalid:           x
-    //            i: 0 1 2 3 4 5 6 ...
-    // ---------------------------------------
-    //       pos[i]: 4 5 0 1 6 2 3
-    // seq_id[i][0]: 0 0 1 1 0 1 0
-    //
-    {
-        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_set[s].set();
-        }
-
-        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_pos[s] = -1;
-        }
-
-        for (int32_t i = 0; i < batch.n_tokens; ++i) {
-            const llama_pos pos = batch.pos[i];
-
-            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-
-                cur_seq_set[seq_id] &= seq_set[i];
-
-                if (cur_seq_set[seq_id].none()) {
-                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-
-                if (pos < cur_seq_pos[seq_id]) {
-                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-            }
-        }
-    }
-
-    split_reset();
-
-    return true;
+    return ubatch;
 }
 
-llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
-    const uint32_t n_tokens = n_seq_tokens*n_seqs;
-
-    clear();
-    split_reset();
-
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
-
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        llama_sbatch_seq & s = seq[seq.size() - 1];
+        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+        add_seq_to_ubatch(ubatch, s, length);
     }
-
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_seq_tokens,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ n_seqs,
-
-        /*.token        =*/ ubatch.token.data(),
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
-    };
-
-    return res;
+    return ubatch;
 }
 
-const llama_batch & llama_batch_allocr::get_batch() const {
-    return batch;
-}
-
-uint32_t llama_batch_allocr::get_n_tokens() const {
-    return batch.n_tokens;
-}
-
-uint32_t llama_batch_allocr::get_n_outputs() const {
-    return n_outputs;
-}
-
-std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
-    return out_ids;
-}
-
-llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
-}
-
-llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
-}
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
+    GGML_ASSERT(batch.n_tokens >= 0);
+    this->batch = &batch;
+    this->n_embd = n_embd;
 
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
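
The `split_equal()` restored in this diff packs slices of equal length into one ubatch, starting from the smallest sequence and stopping when the token budget would overflow or a shared prompt is hit. The standalone sketch below mimics that packing loop on plain vectors; the `seq_info` struct, the sequence lengths and the 16-token budget are invented for illustration and are not part of llama.cpp:

```cpp
// Toy model of the greedy packing in llama_sbatch::split_equal: walk the
// sequences from smallest to largest, give every sequence the same slice
// length, and stop on a shared prompt or when the budget runs out.
#include <cstdio>
#include <vector>

struct seq_info {
    size_t length;   // tokens remaining in this sequence
    int    n_seq_id; // >1 means a shared prompt (kept in its own ubatch)
};

int main() {
    // sequences sorted by descending length, so the smallest is at the back
    std::vector<seq_info> seqs = {{32, 1}, {8, 1}, {5, 1}, {3, 1}};
    const size_t n_ubatch = 16; // token budget per micro-batch

    size_t length = 0;             // per-sequence slice, fixed by the smallest sequence
    size_t n_tokens_in_ubatch = 0;

    for (size_t i = seqs.size(); i-- > 0;) {
        const seq_info & s = seqs[i];
        if (length == 0) {
            length = s.length < n_ubatch ? s.length : n_ubatch;
        }
        n_tokens_in_ubatch += length;
        printf("take %zu token(s) from sequence %zu\n", length, i);
        if (s.n_seq_id > 1) { break; }                          // shared prompt: own ubatch
        if (length + n_tokens_in_ubatch > n_ubatch) { break; }  // budget exhausted
    }
    printf("ubatch holds %zu tokens (budget %zu)\n", n_tokens_in_ubatch, n_ubatch);
    return 0;
}
```
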