#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <ctime>
#include <algorithm>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

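// split a string into parts on the given separator (by default, one prompt per line)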
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
    std::vector<std::string> lines;
    size_t start = 0;
    size_t end = s.find(separator);

    while (end != std::string::npos) {
        lines.push_back(s.substr(start, end - start));
        start = end + separator.length();
        end = s.find(separator, start);
    }

    lines.push_back(s.substr(start)); // Add the last part

    return lines;
}

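// add all tokens of one input to the batch under the given sequence id, requesting output for every position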
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    size_t n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }
}

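// run the model on the batch and write the (optionally normalized) embeddings into `output`,
// one per sequence when pooling is enabled, or one per token when pooling is NONE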
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

    // clear previous kv_cache values (irrelevant for embeddings)
    llama_memory_clear(llama_get_memory(ctx));

    // run model
    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
    if (llama_decode(ctx, batch) < 0) {
        LOG_ERR("%s : failed to process\n", __func__);
    }

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.logits[i]) {
            continue;
        }

        const float * embd = nullptr;
        int embd_pos = 0;

        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            // try to get token embeddings
            embd = llama_get_embeddings_ith(ctx, i);
            embd_pos = i;
            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
        } else {
            // try to get sequence embeddings - supported only when pooling_type is not NONE
            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
            embd_pos = batch.seq_id[i][0];
            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
        }

        float * out = output + embd_pos * n_embd;
        common_embd_normalize(embd, out, n_embd, embd_norm);
    }
}

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
        return 1;
    }

    common_init();

    params.embedding = true;

    // utilize the full context
    if (params.n_batch < params.n_ctx) {
        LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
        params.n_batch = params.n_ctx;
    }

    // for non-causal models, batch size must be equal to ubatch size
    params.n_ubatch = params.n_batch;

    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int n_ctx_train = llama_model_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);

    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
        LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
        return 1;
    }

    if (n_ctx > n_ctx_train) {
        LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, n_ctx);
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    // split the prompt into lines
    std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);

    // max batch size
    const uint64_t n_batch = params.n_batch;

    // tokenize the prompts and trim
    std::vector<std::vector<int32_t>> inputs;
    for (const auto & prompt : prompts) {
        auto inp = common_tokenize(ctx, prompt, true, true);
        if (inp.size() > n_batch) {
            LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                    __func__, (long long int) inp.size(), (long long int) n_batch);
            return 1;
        }
        inputs.push_back(inp);
    }

    // check if the last token is SEP
    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
    for (auto & inp : inputs) {
        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
            LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
        }
    }

    // tokenization stats
    if (params.verbose_prompt) {
        for (int i = 0; i < (int) inputs.size(); i++) {
            LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
            LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
            for (int j = 0; j < (int) inputs[i].size(); j++) {
                LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
            }
            LOG("\n\n");
        }
    }

    // initialize batch
    const int n_prompts = prompts.size();
    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

    // count number of embeddings
    int n_embd_count = 0;
    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
        for (int k = 0; k < n_prompts; k++) {
            n_embd_count += inputs[k].size();
        }
    } else {
        n_embd_count = n_prompts;
    }

    // allocate output
    const int n_embd = llama_model_n_embd(model);
    std::vector<float> embeddings(n_embd_count * n_embd, 0);
    float * emb = embeddings.data();

    // break into batches
    int e = 0; // number of embeddings already stored
    int s = 0; // number of prompts in current batch
    for (int k = 0; k < n_prompts; k++) {
        // clamp to n_batch tokens
        auto & inp = inputs[k];

        const uint64_t n_toks = inp.size();

        // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch) {
            float * out = emb + e * n_embd;
            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
            s = 0;
            common_batch_clear(batch);
        }

        // add to batch
        batch_add_seq(batch, inp, s);
        s += 1;
    }

    // final batch
    float * out = emb + e * n_embd;
    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);

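    // default output: print the embeddings (and, for multiple prompts, the cosine similarity matrix) as plain text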
    if (params.embd_out.empty()) {
        LOG("\n");

        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            for (int j = 0; j < n_embd_count; j++) {
                LOG("embedding %d: ", j);
                for (int i = 0; i < std::min(3, n_embd); i++) {
                    if (params.embd_normalize == 0) {
                        LOG("%6.0f ", emb[j * n_embd + i]);
                    } else {
                        LOG("%9.6f ", emb[j * n_embd + i]);
                    }
                }
                LOG(" ... ");
                for (int i = n_embd - 3; i < n_embd; i++) {
                    if (params.embd_normalize == 0) {
                        LOG("%6.0f ", emb[j * n_embd + i]);
                    } else {
                        LOG("%9.6f ", emb[j * n_embd + i]);
                    }
                }
                LOG("\n");
            }
        } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
            const uint32_t n_cls_out = llama_model_n_cls_out(model);
            std::vector<std::string> cls_out_labels;

            for (uint32_t i = 0; i < n_cls_out; i++) {
                const char * label = llama_model_cls_label(model, i);
                const std::string label_i(label == nullptr ? "" : label);
                cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
            }

            for (int j = 0; j < n_embd_count; j++) {
                for (uint32_t i = 0; i < n_cls_out; i++) {
                    // NOTE: if you change this log - update the tests in ci/run.sh
                    if (n_cls_out == 1) {
                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
                    } else {
                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
                    }
                }
            }
        } else {
            // print the first part of the embeddings or for a single prompt, the full embedding
            for (int j = 0; j < n_prompts; j++) {
                LOG("embedding %d: ", j);
                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
                    if (params.embd_normalize == 0) {
                        LOG("%6.0f ", emb[j * n_embd + i]);
                    } else {
                        LOG("%9.6f ", emb[j * n_embd + i]);
                    }
                }
                LOG("\n");
            }

            // print cosine similarity matrix
            if (n_prompts > 1) {
                LOG("\n");
                LOG("cosine similarity matrix:\n\n");
                for (int i = 0; i < n_prompts; i++) {
                    LOG("%6.6s ", prompts[i].c_str());
                }
                LOG("\n");
                for (int i = 0; i < n_prompts; i++) {
                    for (int j = 0; j < n_prompts; j++) {
                        float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                        LOG("%6.2f ", sim);
                    }
                    LOG("%1.10s", prompts[i].c_str());
                    LOG("\n");
                }
            }
        }
    }

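    // optionally emit the embeddings as JSON ("json", "json+") or as a raw array ("array");
    // "json+" additionally includes the cosine similarity matrix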
    if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
        const bool notArray = params.embd_out != "array";

        LOG(notArray ? "{\n  \"object\": \"list\",\n  \"data\": [\n" : "[");
        for (int j = 0;;) { // at least one iteration (one prompt)
            if (notArray) LOG("    {\n      \"object\": \"embedding\",\n      \"index\": %d,\n      \"embedding\": ", j);
            LOG("[");
            for (int i = 0;;) { // at least one iteration (n_embd > 0)
                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
                i++;
                if (i < n_embd) LOG(","); else break;
            }
            LOG(notArray ? "]\n    }" : "]");
            j++;
            if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
        }
        LOG(notArray ? "\n  ]" : "]\n");

        if (params.embd_out == "json+" && n_prompts > 1) {
            LOG(",\n  \"cosineSimilarity\": [\n");
            for (int i = 0;;) { // at least two iterations (n_embd_count > 1)
                LOG("    [");
                for (int j = 0;;) { // at least two iterations (n_embd_count > 1)
                    float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                    LOG("%6.2f", sim);
                    j++;
                    if (j < n_embd_count) LOG(", "); else break;
                }
                LOG(" ]");
                i++;
                if (i < n_embd_count) LOG(",\n"); else break;
            }
            LOG("\n  ]");
        }

        if (notArray) LOG("\n}\n");
    }

    LOG("\n");
    llama_perf_context_print(ctx);

    // clean up
    llama_batch_free(batch);
    llama_backend_free();

    return 0;
}