@@ -74,13 +74,6 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
7474 return chunks;
7575}
7676
77- static void batch_add_seq (llama_batch_ext * batch, const std::vector<int32_t > & tokens, llama_seq_id seq_id) {
78- const size_t n_tokens = tokens.size ();
79- for (size_t i = 0 ; i < n_tokens; i++) {
80- llama_batch_ext_add_text (batch, tokens[i], i, &seq_id, 1 , true );
81- }
82- }
83-
8477static void batch_decode (llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2 ) {
8578 const llama_model * model = llama_get_model (ctx);
8679
@@ -213,7 +206,7 @@ int main(int argc, char ** argv) {
213206
214207 // initialize batch
215208 const int n_chunks = chunks.size ();
216- llama_batch_ext * batch = llama_batch_ext_init (ctx);
209+ llama_batch_ext_ptr batch (ctx);
217210
218211 // allocate output
219212 const int n_embd = llama_model_n_embd (model);
@@ -230,21 +223,21 @@ int main(int argc, char ** argv) {
230223 const uint64_t n_toks = inp.size ();
231224
232225 // encode if at capacity
233- if (llama_batch_ext_get_n_tokens ( batch) + n_toks > n_batch) {
234- batch_decode (ctx, batch, emb + p * n_embd, s, n_embd);
235- llama_batch_ext_clear ( batch);
226+ if (batch. n_tokens ( ) + n_toks > n_batch) {
227+ batch_decode (ctx, batch. get () , emb + p * n_embd, s, n_embd);
228+ batch. clear ( );
236229
237230 p += s;
238231 s = 0 ;
239232 }
240233
241234 // add to batch
242- batch_add_seq ( batch, inp , s);
235+ batch. add_seq (inp, 0 , s, true );
243236 s += 1 ;
244237 }
245238
246239 // final batch
247- batch_decode (ctx, batch, emb + p * n_embd, s, n_embd);
240+ batch_decode (ctx, batch. get () , emb + p * n_embd, s, n_embd);
248241
249242 // save embeddings to chunks
250243 for (int i = 0 ; i < n_chunks; i++) {
@@ -253,7 +246,7 @@ int main(int argc, char ** argv) {
253246 chunks[i].tokens .clear ();
254247 }
255248
256- llama_batch_ext * query_batch = llama_batch_ext_init (ctx);
249+ llama_batch_ext_ptr query_batch (ctx);
257250
258251 // start loop, receive query and return top k similar chunks based on cosine similarity
259252 std::string query;
@@ -262,12 +255,12 @@ int main(int argc, char ** argv) {
262255 std::getline (std::cin, query);
263256 std::vector<int32_t > query_tokens = common_tokenize (ctx, query, true );
264257
265- batch_add_seq (query_batch, query_tokens , 0 );
258+        query_batch.add_seq(query_tokens, 0, 0, true);
266259
267260 std::vector<float > query_emb (n_embd, 0 );
268- batch_decode (ctx, query_batch, query_emb.data (), 1 , n_embd);
261+ batch_decode (ctx, query_batch. get () , query_emb.data (), 1 , n_embd);
269262
270- llama_batch_ext_clear ( query_batch);
263+ query_batch. clear ( );
271264
272265 // compute cosine similarities
273266 {
@@ -296,9 +289,6 @@ int main(int argc, char ** argv) {
296289 LOG (" \n " );
297290 llama_perf_context_print (ctx);
298291
299- llama_batch_ext_free (batch);
300- llama_batch_ext_free (query_batch);
301-
302292 // clean up
303293 llama_backend_free ();
304294}
0 commit comments