@@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size
     return chunks;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
-    size_t n_tokens = tokens.size();
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    const size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
-            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
-        }
+    for (int s = 0; s < n_seq; s++) {
+        const float * embd = llama_get_embeddings_seq(ctx, s);
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        float * out = output + embd_pos * n_embd;
+        float * out = output + s * n_embd;
         common_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
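The migrated `batch_decode` assumes one pooled embedding per sequence, so the end-to-end pattern for a single input collapses to: fill a batch, decode, read the sequence embedding back. Below is a minimal sketch using only the `llama_batch_ext` calls that appear in this diff; the context setup is assumed, and the context must be configured with a non-`NONE` pooling type, since the token-level path was removed above.

```cpp
// Sketch: embed one tokenized sequence via the llama_batch_ext API used here.
// Assumes a context configured with pooled embeddings (pooling type != NONE).
static std::vector<float> embed_one(llama_context * ctx, const std::vector<int32_t> & tokens, int n_embd) {
    llama_batch_ext * batch = llama_batch_ext_init(tokens.size(), 1);

    llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); i++) {
        // positions restart at 0 per sequence; `true` requests output for the token
        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
    }

    std::vector<float> out(n_embd, 0.0f);
    if (llama_decode_ext(ctx, batch) >= 0) {
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        common_embd_normalize(embd, out.data(), n_embd, 2); // L2-normalize, as batch_decode does
    }

    llama_batch_ext_free(batch);
    return out;
}
```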
@@ -230,7 +213,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_chunks = chunks.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
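On the constructor arguments: by analogy with the classic `llama_batch_init(n_tokens, embd, n_seq_max)`, the first argument is presumably the token capacity and the second the maximum number of sequence ids attached to a single token (the batch still packs many sequences overall; each token just belongs to one). A defensive-initialization sketch, under the assumption that `llama_batch_ext_init` returns `NULL` on allocation failure:

```cpp
// Assumption: llama_batch_ext_init(n_tokens_max, n_seq_max) mirrors
// llama_batch_init and yields NULL if allocation fails.
llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
if (batch == NULL) {
    LOG_ERR("%s: failed to allocate batch for %d tokens\n", __func__, n_batch);
    return 1;
}
```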
@@ -247,10 +230,10 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
-            batch.clear();
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
+            llama_batch_ext_clear(batch);
+
             p += s;
             s = 0;
         }
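The flush logic above relies on two counters: `s` counts sequences currently packed in the batch and `p` counts sequences already decoded into `emb`, so each flush writes `s` embeddings starting at `emb + p * n_embd`. A compressed sketch of the loop body follows; the trailing add/increment step is not shown in this hunk and is assumed.

```cpp
// Capacity-flush pattern, per tokenized chunk `inp`:
if (llama_batch_ext_get_n_tokens(batch) + inp.size() > (size_t) n_batch) {
    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); // flush the s pending sequences
    llama_batch_ext_clear(batch);                          // reuse the same allocation
    p += s;                                                // their embeddings now live in emb
    s  = 0;
}
batch_add_seq(batch, inp, s); // seq ids restart at 0 after each flush (assumed)
s++;
```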
@@ -261,8 +244,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -271,7 +253,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    struct common_batch query_batch = common_batch(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
@@ -285,7 +267,7 @@ int main(int argc, char ** argv) {
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
-        query_batch.clear();
+        llama_batch_ext_clear(query_batch);
 
         // compute cosine similarities
         {
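Because `batch_decode` defaults to `embd_norm = 2`, every stored embedding is unit length, and the cosine similarity computed in this block reduces to a plain dot product. A sketch of that reduction is below; the helper and the `chunk_emb` pointer are hypothetical (llama.cpp's `common` library also ships a cosine-similarity helper).

```cpp
// For unit-length vectors, cos(a, b) = a · b.
static float embd_dot(const float * a, const float * b, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += a[i] * b[i];
    }
    return sum;
}

// usage (chunk_emb is the hypothetical stored embedding of one chunk):
// float sim = embd_dot(query_emb.data(), chunk_emb, n_embd);
```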
@@ -314,6 +296,9 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+    llama_batch_ext_free(query_batch);
+
     // clean up
     llama_backend_free();
 }
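One side effect of replacing `common_batch` (which owned its storage) with a raw `llama_batch_ext *` is that every exit path must now call `llama_batch_ext_free`, as the new lines above do. If early returns ever creep into `main`, a small RAII wrapper would restore the old ownership semantics; a sketch using only the init/free functions from this commit:

```cpp
#include <memory>

// RAII guard: frees the batch when it leaves scope.
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * b) const { llama_batch_ext_free(b); }
};
using llama_batch_ext_ptr = std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>;

// usage:
// llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
// llama_decode_ext(ctx, batch.get());
```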