@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to decode\n", __func__);
+    if (llama_encode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to encode\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
+            batch_encode(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);
             p += s;
             s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_encode(ctx, batch, out, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
-        batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
         common_batch_clear(query_batch);
 