@@ -7,9 +7,9 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
77 // clear empty sequences
88 // the previous ubatch is assumed to be gone,
99 // so nothing should refer to values in these sequences anymore.
10- for (size_t i = seq .size (); i-- > 0 ;) {
11- if (seq [i].length == 0 ) {
12- seq .pop_back ();
10+ for (size_t i = seqs .size (); i-- > 0 ;) {
11+ if (seqs [i].length == 0 ) {
12+ seqs .pop_back ();
1313 } else {
1414 break ;
1515 }
@@ -36,48 +36,48 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
3636}
3737
3838void llama_sbatch::add_seq_to_ubatch (llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39- GGML_ASSERT (batch != nullptr );
39+ GGML_ASSERT (batch_ptr != nullptr );
4040 GGML_ASSERT (length <= seq.length );
4141 // Can only add sequences of equal lengths to a batch,
4242 // otherwise it isn't clear to which sequence a token belongs
4343 GGML_ASSERT (seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t ) ubatch.n_tokens / ubatch.n_seqs );
4444 GGML_ASSERT ((seq.n_seq_id != 0 ) == ubatch.equal_seqs );
4545 // NOTE: loops are separated for cache-friendliness
46- if (batch ->token ) {
46+ if (batch_ptr ->token ) {
4747 if (ubatch.equal_seqs ) {
4848 for (size_t i = 0 ; i < length; ++i) {
49- ubatch.token [ubatch.n_tokens + i] = batch ->token [ids[seq.offset + i]];
49+ ubatch.token [ubatch.n_tokens + i] = batch_ptr ->token [ids[seq.offset + i]];
5050 }
5151 } else {
5252 // simple split
53- ubatch.token = batch ->token + seq.offset ;
53+ ubatch.token = batch_ptr ->token + seq.offset ;
5454 }
5555 } else {
5656 ubatch.token = nullptr ;
5757 }
58- if (batch ->embd ) {
58+ if (batch_ptr ->embd ) {
5959 if (ubatch.equal_seqs ) {
6060 for (size_t i = 0 ; i < length; ++i) {
6161 memcpy (
6262 ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63- batch ->embd + (n_embd * ids[seq.offset + i]),
63+ batch_ptr ->embd + (n_embd * ids[seq.offset + i]),
6464 n_embd * sizeof (float )
6565 );
6666 }
6767 } else {
6868 // simple split
69- ubatch.embd = batch ->embd + (n_embd * seq.offset );
69+ ubatch.embd = batch_ptr ->embd + (n_embd * seq.offset );
7070 }
7171 } else {
7272 ubatch.embd = nullptr ;
7373 }
7474 if (ubatch.equal_seqs ) {
7575 for (size_t i = 0 ; i < length; ++i) {
76- ubatch.pos [ubatch.n_tokens + i] = batch ->pos [ids[seq.offset + i]];
76+ ubatch.pos [ubatch.n_tokens + i] = batch_ptr ->pos [ids[seq.offset + i]];
7777 }
7878 } else {
7979 // simple split
80- ubatch.pos = batch ->pos + seq.offset ;
80+ ubatch.pos = batch_ptr ->pos + seq.offset ;
8181 }
8282 if (ubatch.equal_seqs ) {
8383 ubatch.n_seq_id [ubatch.n_seqs ] = seq.n_seq_id ;
@@ -86,33 +86,33 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
8686 }
8787 } else {
8888 // simple split
89- if (batch ->n_seq_id ) {
90- ubatch.n_seq_id = batch ->n_seq_id + seq.offset ;
89+ if (batch_ptr ->n_seq_id ) {
90+ ubatch.n_seq_id = batch_ptr ->n_seq_id + seq.offset ;
9191 } else {
9292 for (size_t i = 0 ; i < length; ++i) {
9393 ubatch.n_seq_id [ubatch.n_seqs + i] = 1 ;
9494 }
9595 }
96- if (batch ->seq_id ) {
97- ubatch.seq_id = batch ->seq_id + seq.offset ;
96+ if (batch_ptr ->seq_id ) {
97+ ubatch.seq_id = batch_ptr ->seq_id + seq.offset ;
9898 }
9999 }
100100 if (logits_all) {
101101 for (size_t i = 0 ; i < length; ++i) {
102102 ubatch.output [ubatch.n_tokens + i] = 1 ;
103103 out_ids.push_back (ids[seq.offset + i]);
104104 }
105- } else if (batch ->logits ) {
105+ } else if (batch_ptr ->logits ) {
106106 if (ubatch.equal_seqs ) {
107107 for (size_t i = 0 ; i < length; ++i) {
108108 size_t id = ids[seq.offset + i];
109- int8_t is_output = batch ->logits [id];
109+ int8_t is_output = batch_ptr ->logits [id];
110110 ubatch.output [ubatch.n_tokens + i] = is_output;
111111 if (is_output) { out_ids.push_back (id); }
112112 }
113113 } else {
114114 // simple split
115- ubatch.output = batch ->logits + seq.offset ;
115+ ubatch.output = batch_ptr ->logits + seq.offset ;
116116 for (size_t i = 0 ; i < length; ++i) {
117117 if (ubatch.output [i] != 0 ) { out_ids.push_back (seq.offset + i); }
118118 }
@@ -139,28 +139,28 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
139139
// split_simple: build one ubatch of at most n_ubatch tokens from the single
// "simple split" sequence, i.e. the one whose n_seq_id is 0.
//
// @param n_ubatch  requested maximum number of tokens; clamped to n_tokens first.
// @return          the ubatch obtained from reserve_ubatch(), with equal_seqs
//                  set to false. If the sequence list is empty, no tokens are
//                  added by this function.
//
// Diff note: each -/+ pair in this function only renames a member
// (batch -> batch_ptr, seq -> seqs); the surrounding logic is the same.
140140llama_ubatch llama_sbatch::split_simple (size_t n_ubatch) {
 // never ask for more tokens than the sbatch holds
141141 n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
 // has_embd tells reserve_ubatch() whether the source batch has an embd pointer
142- llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch ->embd != nullptr );
142+ llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch_ptr ->embd != nullptr );
 // add_seq_to_ubatch() asserts that equal_seqs is false for a sequence with n_seq_id == 0
143143 ubatch.equal_seqs = false ;
144- if (!seq .empty ()) {
145- llama_sbatch_seq & s = seq [0 ];
144+ if (!seqs .empty ()) {
145+ llama_sbatch_seq & s = seqs [0 ];
 // take min(s.length, n_ubatch) tokens from the sequence
146146 size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147- GGML_ASSERT (seq .size () == 1 && s.n_seq_id == 0 ); // don't mix with other splits
147+ GGML_ASSERT (seqs .size () == 1 && s.n_seq_id == 0 ); // don't mix with other splits
148148 add_seq_to_ubatch (ubatch, s, length);
149149 }
150150 return ubatch;
151151}
152152
153153llama_ubatch llama_sbatch::split_equal (size_t n_ubatch) {
154154 n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155- llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch ->embd != nullptr );
156- if (!seq .empty ()) {
155+ llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch_ptr ->embd != nullptr );
156+ if (!seqs .empty ()) {
157157 size_t length = 0 ;
158158 size_t n_tokens_in_ubatch = 0 ;
159- GGML_ASSERT (seq [0 ].n_seq_id > 0 ); // should not be mixed with simple splits
159+ GGML_ASSERT (seqs [0 ].n_seq_id > 0 ); // should not be mixed with simple splits
160160 // smallest first, because it's easier to split this way;
161161 // starting from the end to pop in constant time.
162- for (size_t i = seq .size (); i-- > 0 ;) {
163- llama_sbatch_seq & s = seq [i];
162+ for (size_t i = seqs .size (); i-- > 0 ;) {
163+ llama_sbatch_seq & s = seqs [i];
164164 GGML_ASSERT (s.length > 0 );
165165 if (length == 0 ) {
166166 length = s.length < n_ubatch ? s.length : n_ubatch;
@@ -179,33 +179,34 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
179179
// split_seq: build one ubatch of at most n_ubatch tokens taken from a single
// sequence, namely the last entry of the sequence list.
//
// @param n_ubatch  requested maximum number of tokens; clamped to n_tokens first.
// @return          the ubatch obtained from reserve_ubatch(). If the sequence
//                  list is empty, no tokens are added by this function.
//
// NOTE(review): equal_seqs is not assigned here. add_seq_to_ubatch() asserts
// (seq.n_seq_id != 0) == ubatch.equal_seqs, so reserve_ubatch() presumably
// leaves equal_seqs true; confirm against reserve_ubatch().
//
// Diff note: besides the member renames (batch -> batch_ptr, seq -> seqs),
// the last-element access changes from seq[seq.size() - 1] to seqs.back();
// both are reached only when the list is non-empty.
180180llama_ubatch llama_sbatch::split_seq (size_t n_ubatch) {
 // never ask for more tokens than the sbatch holds
181181 n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182- llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch ->embd != nullptr );
183- if (!seq .empty ()) {
184- llama_sbatch_seq & s = seq[seq. size () - 1 ] ;
182+ llama_ubatch ubatch = reserve_ubatch (n_ubatch, /* has_embd */ batch_ptr ->embd != nullptr );
183+ if (!seqs .empty ()) {
184+ llama_sbatch_seq & s = seqs. back () ;
 // take min(s.length, n_ubatch) tokens from that sequence
185185 size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186186 GGML_ASSERT (s.n_seq_id > 0 ); // should not be mixed with simple splits
187187 add_seq_to_ubatch (ubatch, s, length);
188188 }
189189 return ubatch;
190190}
191191
192- void llama_sbatch::from_batch (const llama_batch & batch, size_t n_embd , bool simple_split, bool logits_all ) {
192+ void llama_sbatch::from_batch (const llama_batch & batch, size_t n_embd_cur , bool simple_split, bool logits_all_cur ) {
193193 GGML_ASSERT (batch.n_tokens >= 0 );
194- this ->batch = &batch;
195- this ->n_embd = n_embd;
196- this ->logits_all = logits_all;
194+
195+ batch_ptr = &batch;
196+ n_embd = n_embd_cur;
197+ logits_all = logits_all_cur;
197198
198199 n_tokens = batch.n_tokens ;
199200 ids.resize (n_tokens);
200201 out_ids.clear ();
201- // TODO: reserve out_ids and seq
202+ // TODO: reserve out_ids and seqs
202203
203204 for (size_t i = 0 ; i < n_tokens; ++i) {
204205 ids[i] = i;
205206 }
206207 if (simple_split) {
207- seq .resize (1 );
208- llama_sbatch_seq & s = seq [0 ];
208+ seqs .resize (1 );
209+ llama_sbatch_seq & s = seqs [0 ];
209210 s.n_seq_id = 0 ;
210211 s.seq_id = nullptr ;
211212 s.offset = 0 ;
@@ -259,11 +260,11 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
259260 }
260261 }
261262 llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1 };
262- seq .push_back (new_seq);
263- last_seq = &seq .back ();
263+ seqs .push_back (new_seq);
264+ last_seq = &seqs .back ();
264265 }
265266 // keep shared prompts first at the end, then sort by length descending.
266- std::sort (seq .begin (), seq .end (),
267+ std::sort (seqs .begin (), seqs .end (),
267268 [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268269 if (a.n_seq_id == b.n_seq_id ) {
269270 return a.length > b.length ;
0 commit comments