 #include <string>
 #include <vector>
 
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <unistd.h>
@@ -47,6 +48,7 @@ static void print_usage(int argc, char ** argv) {
     LOG("\nexample usage:\n");
     LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128 -no-cnv\n", argv[0]);
     LOG("\n  chat (conversation): %s -m your_model.gguf -sys \"You are a helpful assistant\"\n", argv[0]);
+    LOG("\n  embeddings:          %s -m your_model.gguf --embedding -p \"Hello world\"\n", argv[0]);
     LOG("\n");
 }
 
@@ -83,6 +85,78 @@ static void sigint_handler(int signo) {
 }
 #endif
 
+// Function to generate embeddings
+static bool generate_embeddings(llama_context * ctx, const std::vector<llama_token> & tokens) {
+    // Make sure we have a valid context
+    if (ctx == nullptr) {
+        LOG_ERR("%s: error: context is null\n", __func__);
+        return false;
+    }
+
+    // Create a batch with the input tokens
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        common_batch_add(batch, tokens[i], i, { 0 }, true);
+    }
+
+    // Process the batch
+    if (llama_decode(ctx, batch)) {
+        LOG_ERR("%s: failed to decode\n", __func__);
+        llama_batch_free(batch);
+        return false;
+    }
+
+    // Get embeddings
+    const int n_embd = llama_model_n_embd(llama_get_model(ctx));
+    std::vector<float> embeddings;
+
+    // Determine if we're using sequence-level or token-level embeddings
+    enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        // Sequence-level embedding
+        const float * embd = llama_get_embeddings_seq(ctx, 0);
+        if (embd == nullptr) {
+            LOG_ERR("%s: failed to get sequence embeddings\n", __func__);
+            llama_batch_free(batch);
+            return false;
+        }
+
+        embeddings.assign(embd, embd + n_embd);
+
+        // Output the embeddings
+        LOG_INF("Sequence embedding (dimension: %d):\n", n_embd);
+        printf("[\n");
+        for (int i = 0; i < n_embd; ++i) {
+            printf("  %f%s\n", embeddings[i], i < n_embd - 1 ? "," : "");
+        }
+        printf("]\n");
+    } else {
+        // Token-level embeddings - print for each token
+        LOG_INF("Token-level embeddings (dimension: %d):\n", n_embd);
+        printf("[\n");
+        for (size_t t = 0; t < tokens.size(); ++t) {
+            const float * embd = llama_get_embeddings_ith(ctx, t);
+            if (embd == nullptr) {
+                LOG_ERR("%s: failed to get token embeddings for token %zu\n", __func__, t);
+                continue;
+            }
+
+            // Get the token string representation for reference
+            std::string token_str = common_token_to_piece(ctx, tokens[t]);
+            printf("  // Token %zu: '%s'\n", t, token_str.c_str());
+            printf("  [\n");
+            for (int i = 0; i < n_embd; ++i) {
+                printf("    %f%s\n", embd[i], i < n_embd - 1 ? "," : "");
+            }
+            printf("  ]%s\n", t < tokens.size() - 1 ? "," : "");
+        }
+        printf("]\n");
+    }
+
+    llama_batch_free(batch);
+    return true;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
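
(Note, not part of the diff: the vectors printed by generate_embeddings() above are normally consumed by comparing them against other embeddings rather than read directly. A minimal, hypothetical helper for that, assuming two vectors of the same dimension n_embd, could look like the following sketch.)

#include <cmath>
#include <vector>

// Hypothetical example, not part of this patch: cosine similarity between two
// embedding vectors, e.g. two pooled sequence embeddings of size n_embd.
static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i) {
        dot    += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    if (norm_a == 0.0f || norm_b == 0.0f) {
        return 0.0f; // degenerate case: treat zero vectors as dissimilar
    }
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}

(Values close to 1 indicate similar prompts; values near 0 indicate unrelated ones.)
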
@@ -107,14 +181,6 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.embedding) {
-        LOG_ERR("************\n");
-        LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
-        LOG_ERR("************\n\n");
-
-        return 0;
-    }
-
     if (params.n_ctx != 0 && params.n_ctx < 8) {
         LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
@@ -234,6 +300,53 @@ int main(int argc, char ** argv) {
         LOG_INF("\n");
     }
 
+    // For embedding mode, we only need to process the prompt, generate embeddings, and exit
+    if (params.embedding) {
+        // Make sure we have a prompt
+        if (params.prompt.empty()) {
+            LOG_ERR("%s: error: prompt is required for embedding\n", __func__);
+            return 1;
+        }
+
+        // Enable embeddings for the context
+        llama_set_embeddings(ctx, true);
+
+        // Tokenize the prompt
+        const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
+        std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos, true);
+
+        if (tokens.empty()) {
+            LOG_ERR("%s: error: failed to tokenize prompt\n", __func__);
+            return 1;
+        }
+
+        LOG_INF("%s: generating embeddings for %zu tokens\n", __func__, tokens.size());
+        LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+
+        if (params.verbose_prompt) {
+            LOG_INF("%s: tokens: ", __func__);
+            for (size_t i = 0; i < tokens.size(); ++i) {
+                LOG_INF("%d ('%s') ", tokens[i], common_token_to_piece(ctx, tokens[i]).c_str());
+            }
+            LOG_INF("\n");
+        }
+
+        // Generate and print embeddings
+        if (!generate_embeddings(ctx, tokens)) {
+            LOG_ERR("%s: error: failed to generate embeddings\n", __func__);
+            return 1;
+        }
+
+        // Clean up and exit
+        ggml_threadpool_free_fn(threadpool);
+        if (threadpool_batch) {
+            ggml_threadpool_free_fn(threadpool_batch);
+        }
+        llama_backend_free();
+
+        return 0;
+    }
+
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
 
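
(Note, not part of the diff: taken together, the change means the existing --embedding flag is now handled by this CLI tool itself instead of being rejected with a pointer to the separate 'embedding' tool. An invocation along the lines of the new usage string, e.g. your_binary -m your_model.gguf --embedding -p "Hello world", tokenizes the prompt, decodes it in a single batch, prints either one pooled vector of n_embd floats or one bracketed vector per prompt token for models without a pooling type, then frees the threadpools and backend and exits before the interactive generation path is reached.)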