@@ -26,56 +26,52 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
     return lines;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    const struct llama_model * model = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
+    const int n_tokens = llama_batch_ext_get_n_tokens(batch);
+
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, n_tokens, n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int i = 0; i < n_tokens; i++) {
+            const float * embd = llama_get_embeddings_ith(ctx, i);
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + i * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
         }
+    } else {
+        for (int s = 0; s < n_seq; s++) {
+            const float * embd = llama_get_embeddings_seq(ctx, s);
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+            float * out = output + s * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
 
@@ -171,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // count number of embeddings
     int n_embd_count = 0;
@@ -198,12 +194,12 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.get_n_tokens() + n_toks > n_batch) {
-            float * out = emb + e * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
+
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s;
             s = 0;
-            batch.clear();
+            llama_batch_ext_clear(batch);
         }
 
         // add to batch
@@ -212,8 +208,7 @@ int main(int argc, char ** argv) {
         }
 
     // final batch
-    float * out = emb + e * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+    batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
         LOG("\n");
@@ -318,6 +313,8 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+
     // clean up
     llama_backend_free();
 
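For reference, here is a minimal usage sketch of the `llama_batch_ext` flow exercised by this diff; it is not part of the change itself. It assumes it would live in `examples/embedding/embedding.cpp` next to the helpers above (so the file's existing includes apply), that `ctx` was created with embeddings enabled and a pooling type other than `LLAMA_POOLING_TYPE_NONE` (one embedding per sequence), and that `n_embd` is supplied by the caller; `embed_one` is a hypothetical name, not something this PR adds.

```cpp
// Hypothetical helper, not part of this PR: embed a single tokenized prompt
// using the llama_batch_ext API and the batch_add_seq()/batch_decode()
// helpers shown in the diff above.
// Assumes: ctx has embeddings enabled and a pooling type != LLAMA_POOLING_TYPE_NONE,
// so batch_decode() writes exactly one n_embd-sized embedding per sequence.
static std::vector<float> embed_one(llama_context * ctx,
                                    const std::vector<int32_t> & tokens,
                                    int n_embd, int embd_norm) {
    // one sequence, capacity equal to the prompt length
    llama_batch_ext * batch = llama_batch_ext_init(tokens.size(), 1);

    // queue the whole prompt as sequence 0
    batch_add_seq(batch, tokens, /*seq_id=*/0);

    // decode and collect the normalized pooled embedding for sequence 0
    std::vector<float> emb(n_embd, 0.0f);
    batch_decode(ctx, batch, emb.data(), /*n_seq=*/1, n_embd, embd_norm);

    llama_batch_ext_free(batch);
    return emb;
}
```

With several prompts per batch, the same pattern scales the way `main()` does above: pass an incrementing `seq_id` to `batch_add_seq()`, size the output buffer by the number of sequences (or tokens, when pooling is `NONE`), and call `llama_batch_ext_clear()` between batches.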