@@ -53,28 +53,37 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
     }
 
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
-            continue;
-        }
+    const float * embd = nullptr;
+    int embd_pos = 0;
 
-        const float * embd = nullptr;
-        int embd_pos = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE)
+    {
+        for (int i = 0; i < batch.n_tokens; i++)
+        {
+            if (!batch.logits[i]) {
+                continue;
+            }
 
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
             embd = llama_get_embeddings_ith(ctx, i);
             embd_pos = i;
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            embd_pos = batch.seq_id[i][0];
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
         }
+    }
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+    else
+    {
+        for (int i = 0; i < n_seq; i++)
+        {
+            embd = llama_get_embeddings_seq(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
 
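A minimal caller-side sketch of what the reworked loop implies for output sizing. This is not part of the commit; the parameter names n_seq, n_embd, and embd_norm are assumptions taken from the function body, since the hunk header truncates the signature:

#include <vector>
#include "llama.h"

// Hypothetical helper, not part of this commit. With pooling enabled the
// reworked loop writes one embedding per sequence; with
// LLAMA_POOLING_TYPE_NONE it still writes one embedding per output token,
// so the output buffer is sized accordingly before calling batch_decode.
static std::vector<float> decode_embeddings(llama_context * ctx, llama_batch & batch,
                                            int n_seq, int n_embd, int embd_norm) {
    const bool no_pooling = llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE;
    const size_t n_rows   = no_pooling ? (size_t) batch.n_tokens : (size_t) n_seq;

    std::vector<float> output(n_rows * n_embd);
    batch_decode(ctx, batch, output.data(), n_seq, n_embd, embd_norm);
    return output;
}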