@@ -31,25 +31,47 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 }

 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const struct llama_model * model = llama_get_model(ctx);
+
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);

     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        fprintf(stderr, "%s : failed to decode\n", __func__);
+    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
+        // encoder-only model
+        if (llama_encode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to encode\n", __func__);
+        }
+    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        // decoder-only model
+        if (llama_decode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to decode\n", __func__);
+        }
     }

     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
         }

-        // try to get sequence embeddings - supported only when pooling_type is not NONE
-        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        const float * embd = nullptr;
+        int embd_pos = 0;
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            // try to get token embeddings
+            embd = llama_get_embeddings_ith(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+        } else {
+            // try to get sequence embeddings - supported only when pooling_type is not NONE
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            embd_pos = batch.seq_id[i][0];
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        }

-        float * out = output + batch.seq_id[i][0] * n_embd;
+        float * out = output + embd_pos * n_embd;
         llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
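
Both paths above end by handing the selected embedding to llama_embd_normalize before it is stored. For the default Euclidean case this amounts to rescaling the vector to unit length, so that later dot products behave as cosine similarities. A minimal self-contained sketch of that case; embd_normalize_l2 is a hypothetical stand-in, assuming embd_norm == 2 selects Euclidean normalization as in common.cpp:

    #include <cmath>
    #include <cstdio>

    // Hypothetical stand-in for the embd_norm == 2 path of llama_embd_normalize:
    // scale the vector to unit Euclidean length.
    static void embd_normalize_l2(const float * inp, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) inp[i] * inp[i];
        }
        const double scale = sum > 0.0 ? 1.0 / std::sqrt(sum) : 0.0;
        for (int i = 0; i < n; i++) {
            out[i] = (float) (inp[i] * scale);
        }
    }

    int main() {
        const float v[4] = { 3.0f, 0.0f, 4.0f, 0.0f };
        float u[4];
        embd_normalize_l2(v, u, 4);
        printf("%.2f %.2f %.2f %.2f\n", u[0], u[1], u[2], u[3]); // 0.60 0.00 0.80 0.00
    }
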
@@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);

     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+
+    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
         return 1;
     }

@@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
     const int n_prompts = prompts.size();
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

+    // count number of embeddings
+    int n_embd_count = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int k = 0; k < n_prompts; k++) {
+            n_embd_count += inputs[k].size();
+        }
+    } else {
+        n_embd_count = n_prompts;
+    }
+
     // allocate output
     const int n_embd = llama_n_embd(model);
-    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();

     // break into batches
-    int p = 0; // number of prompts processed already
+    int e = 0; // number of embeddings already stored
     int s = 0; // number of prompts in current batch
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
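
This hunk makes the output buffer pooling-aware: with pooling NONE every token yields its own embedding row, while any other pooling type keeps one row per prompt. A small self-contained sketch of the sizing arithmetic, with invented token counts for illustration:

    #include <cstdio>
    #include <vector>

    int main() {
        // Invented token counts for three tokenized prompts.
        const std::vector<std::vector<int>> inputs = { { 1, 2, 3 }, { 4, 5 }, { 6 } };
        const int n_prompts = (int) inputs.size();

        for (const bool pooling_none : { true, false }) {
            int n_embd_count = 0;
            if (pooling_none) {
                // pooling NONE: one embedding row per token
                for (const auto & toks : inputs) {
                    n_embd_count += (int) toks.size();
                }
            } else {
                // any other pooling: one row per prompt
                n_embd_count = n_prompts;
            }
            printf("pooling %s -> %d rows\n", pooling_none ? "none" : "pooled", n_embd_count);
        }
        // prints: pooling none -> 6 rows
        //         pooling pooled -> 3 rows
    }
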
@@ -169,11 +202,11 @@ int main(int argc, char ** argv) {

         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
+            float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            llama_batch_clear(batch);
-            p += s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
+            llama_batch_clear(batch);
         }

         // add to batch
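
The reordering in this hunk is the subtle part: with pooling NONE the flush writes batch.n_tokens rows, so the write cursor e must be advanced while the batch still holds its token count, and only afterwards may llama_batch_clear reset it. A condensed sketch of the flush step; flush_batch is a hypothetical helper that assumes the example's existing includes and its batch_decode function:

    // Hypothetical helper mirroring the reordered flush logic above.
    // Returns the new write cursor into the output buffer.
    static int flush_batch(llama_context * ctx, llama_batch & batch, float * emb, int e, int s,
                           int n_embd, enum llama_pooling_type pooling_type, int embd_norm) {
        batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, embd_norm);
        // one row per token with NONE pooling, one row per prompt otherwise;
        // read n_tokens *before* clearing the batch
        e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
        llama_batch_clear(batch);
        return e;
    }
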
@@ -182,39 +215,62 @@ int main(int argc, char ** argv) {
182215 }
183216
184217 // final batch
185- float * out = emb + p * n_embd;
218+ float * out = emb + e * n_embd;
186219 batch_decode (ctx, batch, out, s, n_embd, params.embd_normalize );
187220
188221 if (params.embd_out .empty ()) {
189- // print the first part of the embeddings or for a single prompt, the full embedding
190222 fprintf (stdout, " \n " );
191- for (int j = 0 ; j < n_prompts; j++) {
192- fprintf (stdout, " embedding %d: " , j);
193- for (int i = 0 ; i < (n_prompts > 1 ? std::min (16 , n_embd) : n_embd); i++) {
194- if (params.embd_normalize == 0 ) {
195- fprintf (stdout, " %6.0f " , emb[j * n_embd + i]);
196- } else {
197- fprintf (stdout, " %9.6f " , emb[j * n_embd + i]);
223+
224+ if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
225+ for (int j = 0 ; j < n_embd_count; j++) {
226+ fprintf (stdout, " embedding %d: " , j);
227+ for (int i = 0 ; i < std::min (3 , n_embd); i++) {
228+ if (params.embd_normalize == 0 ) {
229+ fprintf (stdout, " %6.0f " , emb[j * n_embd + i]);
230+ } else {
231+ fprintf (stdout, " %9.6f " , emb[j * n_embd + i]);
232+ }
233+ }
234+ fprintf (stdout, " ... " );
235+ for (int i = n_embd - 3 ; i < n_embd; i++) {
236+ if (params.embd_normalize == 0 ) {
237+ fprintf (stdout, " %6.0f " , emb[j * n_embd + i]);
238+ } else {
239+ fprintf (stdout, " %9.6f " , emb[j * n_embd + i]);
240+ }
198241 }
242+ fprintf (stdout, " \n " );
199243 }
200- fprintf (stdout, " \n " );
201- }
202-
203- // print cosine similarity matrix
204- if (n_prompts > 1 ) {
205- fprintf (stdout, " \n " );
206- printf (" cosine similarity matrix:\n\n " );
207- for (int i = 0 ; i < n_prompts; i++) {
208- fprintf (stdout, " %6.6s " , prompts[i].c_str ());
244+ } else {
245+ // print the first part of the embeddings or for a single prompt, the full embedding
246+ for (int j = 0 ; j < n_prompts; j++) {
247+ fprintf (stdout, " embedding %d: " , j);
248+ for (int i = 0 ; i < (n_prompts > 1 ? std::min (16 , n_embd) : n_embd); i++) {
249+ if (params.embd_normalize == 0 ) {
250+ fprintf (stdout, " %6.0f " , emb[j * n_embd + i]);
251+ } else {
252+ fprintf (stdout, " %9.6f " , emb[j * n_embd + i]);
253+ }
254+ }
255+ fprintf (stdout, " \n " );
209256 }
210- fprintf (stdout, " \n " );
211- for (int i = 0 ; i < n_prompts; i++) {
212- for (int j = 0 ; j < n_prompts; j++) {
213- float sim = llama_embd_similarity_cos (emb + i * n_embd, emb + j * n_embd, n_embd);
214- fprintf (stdout, " %6.2f " , sim);
257+
258+ // print cosine similarity matrix
259+ if (n_prompts > 1 ) {
260+ fprintf (stdout, " \n " );
261+ printf (" cosine similarity matrix:\n\n " );
262+ for (int i = 0 ; i < n_prompts; i++) {
263+ fprintf (stdout, " %6.6s " , prompts[i].c_str ());
215264 }
216- fprintf (stdout, " %1.10s" , prompts[i].c_str ());
217265 fprintf (stdout, " \n " );
266+ for (int i = 0 ; i < n_prompts; i++) {
267+ for (int j = 0 ; j < n_prompts; j++) {
268+ float sim = llama_embd_similarity_cos (emb + i * n_embd, emb + j * n_embd, n_embd);
269+ fprintf (stdout, " %6.2f " , sim);
270+ }
271+ fprintf (stdout, " %1.10s" , prompts[i].c_str ());
272+ fprintf (stdout, " \n " );
273+ }
218274 }
219275 }
220276 }
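
In the default (non-JSON) output the two modes now print differently: under pooling NONE each line is one token embedding, previewed by its first and last three components, while pooled runs keep the old per-prompt preview and the cosine similarity matrix. The truncation logic, pulled out as a runnable sketch; print_row_preview is a hypothetical helper and, like the branch above, assumes n_embd >= 3:

    #include <algorithm>
    #include <cstdio>

    // Hypothetical helper with the same head/tail preview logic as the
    // NONE-pooling branch: first three and last three components of a row.
    static void print_row_preview(const float * row, int n_embd) {
        for (int i = 0; i < std::min(3, n_embd); i++) {
            printf("%9.6f ", row[i]);
        }
        printf(" ... ");
        for (int i = n_embd - 3; i < n_embd; i++) { // assumes n_embd >= 3
            printf("%9.6f ", row[i]);
        }
        printf("\n");
    }

    int main() {
        const float row[8] = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f };
        print_row_preview(row, 8);
        // prints:  0.100000  0.200000  0.300000  ...  0.600000  0.700000  0.800000
    }
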
@@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
             }
             fprintf(stdout, notArray ? "]\n    }" : "]");
             j++;
-            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
+            if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
         fprintf(stdout, notArray ? "\n  ]" : "]\n");

         if (params.embd_out == "json+" && n_prompts > 1) {
             fprintf(stdout, ",\n  \"cosineSimilarity\": [\n");
-            for (int i = 0;;) { // at least two iteration (n_prompts > 1)
+            for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
                 fprintf(stdout, "    [");
-                for (int j = 0;;) { // at least two iteration (n_prompts > 1)
+                for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
                     float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                     fprintf(stdout, "%6.2f", sim);
                     j++;
-                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                    if (j < n_embd_count) fprintf(stdout, ", "); else break;
                 }
                 fprintf(stdout, " ]");
                 i++;
-                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+                if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
             }
             fprintf(stdout, "\n  ]");
         }
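
Both the plain matrix and the json+ field rest on llama_embd_similarity_cos. A self-contained sketch of the standard cosine similarity it computes; similarity_cos is a hypothetical stand-in, not the library function:

    #include <cmath>
    #include <cstdio>

    // Hypothetical stand-in for llama_embd_similarity_cos:
    // cos(a, b) = (a . b) / (|a| * |b|), in [-1, 1].
    static float similarity_cos(const float * a, const float * b, int n) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < n; i++) {
            dot += (double) a[i] * b[i];
            na  += (double) a[i] * a[i];
            nb  += (double) b[i] * b[i];
        }
        if (na == 0.0 || nb == 0.0) {
            return 0.0f;
        }
        return (float) (dot / (std::sqrt(na) * std::sqrt(nb)));
    }

    int main() {
        const float a[3] = { 1.0f, 0.0f, 0.0f };
        const float b[3] = { 0.7071f, 0.7071f, 0.0f };
        printf("%.2f\n", similarity_cos(a, b, 3)); // ~0.71
    }
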