Skip to content

Commit 1fccfc9

Browse files
committed
Removed an unnecessary iteration over the batch's n_tokens during sequence-embedding generation.
1 parent 9ba399d commit 1fccfc9

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

examples/embedding/embedding.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,37 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
5353
}
5454
}
5555

56-
for (int i = 0; i < batch.n_tokens; i++) {
57-
if (!batch.logits[i]) {
58-
continue;
59-
}
56+
const float* embd = nullptr;
57+
int embd_pos = 0;
6058

61-
const float * embd = nullptr;
62-
int embd_pos = 0;
59+
if(pooling_type == LLAMA_POOLING_TYPE_NONE)
60+
{
61+
for (int i = 0; i < batch.n_tokens; i++)
62+
{
63+
if (!batch.logits[i]) {
64+
continue;
65+
}
6366

64-
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
65-
// try to get token embeddings
6667
embd = llama_get_embeddings_ith(ctx, i);
6768
embd_pos = i;
6869
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
69-
} else {
70-
// try to get sequence embeddings - supported only when pooling_type is not NONE
71-
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
72-
embd_pos = batch.seq_id[i][0];
73-
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
70+
71+
float * out = output + embd_pos * n_embd;
72+
common_embd_normalize(embd, out, n_embd, embd_norm);
7473
}
74+
}
7575

76-
float * out = output + embd_pos * n_embd;
77-
common_embd_normalize(embd, out, n_embd, embd_norm);
76+
else
77+
{
78+
for(int i = 0; i < n_seq; i++)
79+
{
80+
embd = llama_get_embeddings_seq(ctx, i);
81+
embd_pos = i;
82+
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
83+
84+
float * out = output + embd_pos * n_embd;
85+
common_embd_normalize(embd, out, n_embd, embd_norm);
86+
}
7887
}
7988
}
8089

0 commit comments

Comments (0)