Commit 44da159

Add embedding support to llama-cli
1 parent 5215b91 commit 44da159

2 files changed, 122 insertions(+), 9 deletions(-)

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -2720,7 +2720,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.embedding = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
         string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
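
Note: adding LLAMA_EXAMPLE_MAIN here is what exposes the existing --embedding flag (and the LLAMA_ARG_EMBEDDINGS environment variable it registers) to llama-cli in addition to llama-server. Going by the usage string added in tools/main/main.cpp below, a typical invocation would be:

    llama-cli -m your_model.gguf --embedding -p "Hello world"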

tools/main/main.cpp

Lines changed: 121 additions & 8 deletions
@@ -15,6 +15,7 @@
 #include <string>
 #include <vector>
 
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <unistd.h>
@@ -47,6 +48,7 @@ static void print_usage(int argc, char ** argv) {
     LOG("\nexample usage:\n");
     LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128 -no-cnv\n", argv[0]);
     LOG("\n chat (conversation): %s -m your_model.gguf -sys \"You are a helpful assistant\"\n", argv[0]);
+    LOG("\n embeddings: %s -m your_model.gguf --embedding -p \"Hello world\"\n", argv[0]);
     LOG("\n");
 }
 
@@ -83,6 +85,78 @@ static void sigint_handler(int signo) {
 }
 #endif
 
+// Function to generate embeddings
+static bool generate_embeddings(llama_context * ctx, const std::vector<llama_token> & tokens) {
+    // Make sure we have a valid context
+    if (ctx == nullptr) {
+        LOG_ERR("%s: error: context is null\n", __func__);
+        return false;
+    }
+
+    // Create a batch with the input tokens
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        common_batch_add(batch, tokens[i], i, { 0 }, true);
+    }
+
+    // Process the batch
+    if (llama_decode(ctx, batch)) {
+        LOG_ERR("%s: failed to decode\n", __func__);
+        llama_batch_free(batch);
+        return false;
+    }
+
+    // Get embeddings
+    const int n_embd = llama_model_n_embd(llama_get_model(ctx));
+    std::vector<float> embeddings;
+
+    // Determine if we're using sequence-level or token-level embeddings
+    enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        // Sequence-level embedding
+        const float * embd = llama_get_embeddings_seq(ctx, 0);
+        if (embd == nullptr) {
+            LOG_ERR("%s: failed to get sequence embeddings\n", __func__);
+            llama_batch_free(batch);
+            return false;
+        }
+
+        embeddings.assign(embd, embd + n_embd);
+
+        // Output the embeddings
+        LOG_INF("Sequence embedding (dimension: %d):\n", n_embd);
+        printf("[\n");
+        for (int i = 0; i < n_embd; ++i) {
+            printf(" %f%s\n", embeddings[i], i < n_embd - 1 ? "," : "");
+        }
+        printf("]\n");
+    } else {
+        // Token-level embeddings - print for each token
+        LOG_INF("Token-level embeddings (dimension: %d):\n", n_embd);
+        printf("[\n");
+        for (size_t t = 0; t < tokens.size(); ++t) {
+            const float * embd = llama_get_embeddings_ith(ctx, t);
+            if (embd == nullptr) {
+                LOG_ERR("%s: failed to get token embeddings for token %zu\n", __func__, t);
+                continue;
+            }
+
+            // Get the token string representation for reference
+            std::string token_str = common_token_to_piece(ctx, tokens[t]);
+            printf(" // Token %zu: '%s'\n", t, token_str.c_str());
+            printf(" [\n");
+            for (int i = 0; i < n_embd; ++i) {
+                printf(" %f%s\n", embd[i], i < n_embd - 1 ? "," : "");
+            }
+            printf(" ]%s\n", t < tokens.size() - 1 ? "," : "");
+        }
+        printf("]\n");
+    }
+
+    llama_batch_free(batch);
+    return true;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
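
The new generate_embeddings() above prints either one pooled vector for the whole sequence or one vector per token, depending on llama_pooling_type(ctx). For a sense of how such output is typically consumed, here is a minimal, self-contained sketch (not part of this commit) that compares two embedding vectors with cosine similarity; the helper name cosine_similarity and the sample values are purely illustrative.

// Minimal sketch (not part of this commit): comparing two embedding vectors,
// e.g. the sequence-level vectors printed by generate_embeddings().
// The helper name cosine_similarity and the sample values are hypothetical.
#include <cmath>
#include <cstdio>
#include <vector>

static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    // Assumes both vectors have the same dimension (n_embd).
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i) {
        dot    += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    if (norm_a == 0.0f || norm_b == 0.0f) {
        return 0.0f; // undefined for zero vectors; return 0 by convention here
    }
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}

int main() {
    // Stand-in values; in practice these would come from two llama-cli --embedding runs.
    const std::vector<float> embd_a = {0.1f, 0.2f, 0.3f};
    const std::vector<float> embd_b = {0.1f, 0.1f, 0.4f};
    printf("cosine similarity: %f\n", cosine_similarity(embd_a, embd_b));
    return 0;
}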
@@ -107,14 +181,6 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.embedding) {
-        LOG_ERR("************\n");
-        LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
-        LOG_ERR("************\n\n");
-
-        return 0;
-    }
-
     if (params.n_ctx != 0 && params.n_ctx < 8) {
         LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
@@ -234,6 +300,53 @@ int main(int argc, char ** argv) {
         LOG_INF("\n");
     }
 
+    // For embedding mode, we only need to process the prompt, generate embeddings, and exit
+    if (params.embedding) {
+        // Make sure we have a prompt
+        if (params.prompt.empty()) {
+            LOG_ERR("%s: error: prompt is required for embedding\n", __func__);
+            return 1;
+        }
+
+        // Enable embeddings for the context
+        llama_set_embeddings(ctx, true);
+
+        // Tokenize the prompt
+        const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
+        std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos, true);
+
+        if (tokens.empty()) {
+            LOG_ERR("%s: error: failed to tokenize prompt\n", __func__);
+            return 1;
+        }
+
+        LOG_INF("%s: generating embeddings for %zu tokens\n", __func__, tokens.size());
+        LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+
+        if (params.verbose_prompt) {
+            LOG_INF("%s: tokens: ", __func__);
+            for (size_t i = 0; i < tokens.size(); ++i) {
+                LOG_INF("%d ('%s') ", tokens[i], common_token_to_piece(ctx, tokens[i]).c_str());
+            }
+            LOG_INF("\n");
+        }
+
+        // Generate and print embeddings
+        if (!generate_embeddings(ctx, tokens)) {
+            LOG_ERR("%s: error: failed to generate embeddings\n", __func__);
+            return 1;
+        }
+
+        // Clean up and exit
+        ggml_threadpool_free_fn(threadpool);
+        if (threadpool_batch) {
+            ggml_threadpool_free_fn(threadpool_batch);
+        }
+        llama_backend_free();
+
+        return 0;
+    }
+
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
 
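With this block in place, an --embedding run tokenizes the prompt, decodes it once, prints the vector(s) via generate_embeddings(), frees the threadpools and the backend, and returns before the normal generation loop. Since the flag is also registered with set_env("LLAMA_ARG_EMBEDDINGS") in common/arg.cpp, the same mode can presumably be enabled through the environment, e.g.:

    LLAMA_ARG_EMBEDDINGS=1 llama-cli -m your_model.gguf -p "Hello world"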
