@@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size
     return chunks;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
-    size_t n_tokens = tokens.size();
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    const size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
-            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
-        }
+    for (int s = 0; s < n_seq; s++) {
+        const float * embd = llama_get_embeddings_seq(ctx, s);
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        float * out = output + embd_pos * n_embd;
+        float * out = output + s * n_embd;
         common_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
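
Note that the reworked `batch_decode` drops the old `LLAMA_POOLING_TYPE_NONE` branch and always reads pooled per-sequence embeddings via `llama_get_embeddings_seq`, so callers must create the context with a pooling type other than `NONE`. A minimal, self-contained sketch of the expected setup (the context-params fields and the `llama_init_from_model`/`llama_free` calls reflect the current llama.cpp API and are not part of this patch; `embed_one` is a hypothetical helper):

```cpp
#include <vector>

#include "llama.h"

// Hypothetical helper (not in the patch): embed one tokenized chunk using the
// batch_add_seq/batch_decode pair introduced above.
static std::vector<float> embed_one(llama_model * model, const std::vector<int32_t> & tokens, int n_batch) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // request embeddings instead of logits
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // must not be NONE for llama_get_embeddings_seq

    llama_context * ctx = llama_init_from_model(model, cparams);

    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
    batch_add_seq(batch, tokens, /*seq_id=*/0);      // marks every token as an output

    const int n_embd = llama_model_n_embd(model);
    std::vector<float> emb(n_embd);
    batch_decode(ctx, batch, emb.data(), /*n_seq=*/1, n_embd);  // emb now holds the normalized embedding

    llama_batch_ext_free(batch);
    llama_free(ctx);
    return emb;
}
```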
@@ -230,7 +213,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_chunks = chunks.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
@@ -247,10 +230,10 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
-            batch.clear();
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
+            llama_batch_ext_clear(batch);
+
             p += s;
             s = 0;
         }
@@ -261,8 +244,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -271,7 +253,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    struct common_batch query_batch = common_batch(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
@@ -285,7 +267,7 @@ int main(int argc, char ** argv) {
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
-        query_batch.clear();
+        llama_batch_ext_clear(query_batch);
 
         // compute cosine similarities
         {
@@ -314,6 +296,9 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+    llama_batch_ext_free(query_batch);
+
     // clean up
     llama_backend_free();
 }
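
For reference, the `common_batch` → `llama_batch_ext` mapping applied throughout this diff, condensed into one place (a review aid, not code from the patch):

```cpp
// common_batch (wrapper, removed; managed       llama_batch_ext (raw API, used above)
// its own lifetime)                         ->
//
// common_batch batch(n_batch, 1);           ->  llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
// batch.add_text(tok, pos, seq_id, true);   ->  llama_batch_ext_add_text(batch, tok, pos, &seq_id, 1, true);
// llama_batch_ext_get_n_tokens(batch.get()) ->  llama_batch_ext_get_n_tokens(batch)
// llama_decode_ext(ctx, batch.get())        ->  llama_decode_ext(ctx, batch)
// batch.clear();                            ->  llama_batch_ext_clear(batch);
// (freed by the wrapper's destructor)       ->  llama_batch_ext_free(batch);  // now explicit, see above
```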