@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+    llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
 
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {
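
The batch is now an opaque `llama_batch_ext *` handle rather than a plain struct, and the old middle argument to `llama_batch_init` (the embedding size, 0 for token batches) no longer appears in the constructor call. A minimal sketch of the allocate/release pairing, using only the two lifecycle calls visible in this commit; the nullptr check is an assumption about the new API, not something shown in the diff:

```cpp
// Sketch based on the calls visible in this diff, not on a verified header.
// llama_batch_ext_init(n_tokens_max, n_seq_max) allocates an opaque batch;
// llama_batch_ext_free() releases it once decoding is done.
llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens_max =*/ 512, /*n_seq_max =*/ n_parallel);
if (batch == nullptr) {
    LOG_ERR("%s: failed to allocate batch\n", __func__); // assumed failure mode
    return 1;
}

// ... add tokens and call llama_decode_ext(ctx, batch) ...

llama_batch_ext_free(batch);
```
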
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
 
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        common_batch_add(batch, tokens_list[i], i, seq_ids, false);
+        llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
     }
-    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
+    GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
 
     if (llama_model_has_encoder(model)) {
-        if (llama_encode(ctx, batch)) {
+        if (llama_encode_ext(ctx, batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return 1;
         }
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
             decoder_start_token_id = llama_vocab_bos(vocab);
         }
 
-        common_batch_clear(batch);
-        common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+        llama_batch_ext_clear(batch);
+        llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    llama_batch_ext_set_logits_last(batch);
 
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode_ext(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
         return 1;
     }
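
Requesting logits for the last prompt token is now a dedicated call instead of a write into `batch.logits[]`. A short sketch of how that feeds into sampling, assuming the index convention carries over from the old API (the diff initializes `i_batch` to `n_tokens - 1` below) and using the standard `llama_sampler_sample` helper, which is not part of this diff:

```cpp
// Sketch: mark the final prompt token for logits, decode, then sample from it.
llama_batch_ext_set_logits_last(batch);

if (llama_decode_ext(ctx, batch) != 0) {
    LOG_ERR("%s: llama_decode() failed\n", __func__);
    return 1;
}

// assumption: logits are addressed by batch index, as with the old API
const int32_t i_last = llama_batch_ext_get_n_tokens(batch) - 1;
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_last);
```
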
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
 
     // remember the batch index of the last token for each parallel sequence
     // we need this to determine which logits to sample from
-    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+    std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
 
-    int n_cur = batch.n_tokens;
+    int n_cur = llama_batch_ext_get_n_tokens(batch);
     int n_decode = 0;
 
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_predict) {
         // prepare the next batch
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
 
         // sample the next token for each parallel sequence / stream
         for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
 
             streams[i] += common_token_to_piece(ctx, new_token_id);
 
-            i_batch[i] = batch.n_tokens;
+            i_batch[i] = llama_batch_ext_get_n_tokens(batch);
 
             // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_cur, { i }, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false);
 
             n_decode += 1;
         }
 
         // all streams are finished
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
             break;
         }
 
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch)) {
             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
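
Condensed, the per-iteration pattern in the generation loop is: clear the batch, record where each stream's next token will land, queue the sampled token for that stream's sequence only (a seq-id pointer plus a count of 1), and stop once nothing was added. A sketch under the assumption that finished streams are marked with a negative `i_batch` entry, as in the surrounding example code, and that `llama_sampler_sample` (not part of this diff) is used for sampling:

```cpp
// Sketch of the generation loop with the new batch API, condensed from the
// hunks above; only the batch calls shown in this diff are used.
while (n_cur <= n_predict) {
    llama_batch_ext_clear(batch);

    for (int32_t i = 0; i < n_parallel; ++i) {
        if (i_batch[i] < 0) {
            continue; // this stream has already finished
        }

        // sample the next token for stream i from the logits stored at i_batch[i]
        const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);

        // remember where this token's logits will land, then queue it for sequence i only
        i_batch[i] = llama_batch_ext_get_n_tokens(batch);
        llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false);
    }

    // all streams are finished once nothing was added
    if (llama_batch_ext_get_n_tokens(batch) == 0) {
        break;
    }

    n_cur += 1;

    if (llama_decode_ext(ctx, batch) != 0) {
        return 1;
    }
}
```
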
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
 
     fprintf(stderr, "\n");
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     llama_sampler_free(smpl);
     llama_free(ctx);
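
Taken together, the mechanical mapping in this commit is: `llama_batch_init`/`llama_batch_free` become `llama_batch_ext_init`/`llama_batch_ext_free`; `common_batch_add` and `common_batch_clear` become `llama_batch_ext_add_text` (with an explicit seq-id pointer and count) and `llama_batch_ext_clear`; direct reads of `batch.n_tokens` go through `llama_batch_ext_get_n_tokens`; the `batch.logits[n_tokens - 1] = true` write becomes `llama_batch_ext_set_logits_last`; and `llama_encode`/`llama_decode` are replaced by `llama_encode_ext`/`llama_decode_ext`.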