@@ -1210,7 +1210,7 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void send_embedding(server_slot &slot)
+    void send_embedding(server_slot & slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
         res.stop = true;

         const int n_embd = llama_n_embd(model);
+
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
         }
         else
         {
-            const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
-            res.result_json = json
-            {
-                {"embedding", embedding},
-            };
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
+
+                res.result_json = json
+                {
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
+                };
+            }
         }
         queue_results.send(res);
     }
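
The rewritten branch scans the rows of the batch view that was just decoded (hence the new `batch` parameter, passed as `batch_view` at the call site below), and only rows that belong to this slot and were flagged for output (`batch.logits[i]`) are considered. The lookup order distills to the sketch below, using the same llama.cpp calls as the hunk above; `resolve_embd` is a hypothetical helper name, not part of the change.

    #include "llama.h"

    // Resolve the embedding pointer for row i of the last decoded batch.
    // With mean/cls pooling (or a pooled model default), llama_get_embeddings_seq()
    // yields one pooled vector per sequence; with pooling "none" it returns
    // NULL and the per-row vector from llama_get_embeddings_ith() is the
    // fallback. Either pointer refers to n_embd floats owned by the context.
    static const float * resolve_embd(llama_context * ctx, const llama_batch & batch, int32_t i) {
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            embd = llama_get_embeddings_ith(ctx, i); // unpooled fallback
        }
        return embd; // may still be NULL if row i produced no output
    }

Note that `res.result_json` is overwritten on each matching row and `queue_results.send(res)` runs once after the loop, so only the last matching row's embedding is actually sent.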
@@ -1845,7 +1863,7 @@ struct llama_server_context
                        ga_i += ga_w/ga_n;
                    }
                }
-                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id}, false);
+                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                slot_npast++;
            }

@@ -1881,7 +1899,7 @@ struct llama_server_context

        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
        {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            for (auto & slot : slots)
            {
@@ -1954,7 +1972,7 @@ struct llama_server_context
                // prompt evaluated for embedding
                if (slot.embedding)
                {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                    slot.release();
                    slot.i_batch = -1;
                    continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                            pooling type for embeddings, use model default if unspecified\n");
     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.yarn_beta_slow = std::stof(argv[i]);
        }
+        else if (arg == "--pooling")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /* */ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else { invalid_param = true; break; }
+        }
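
For the parsed flag to take effect, `params.pooling_type` still has to reach the llama context at creation time. A minimal sketch of that wiring, assuming a llama.h revision where `llama_context_params` exposes `embeddings` and `pooling_type` fields (the helper name is hypothetical):

    #include "llama.h"

    // Hypothetical helper: build an embedding context that honors --pooling.
    // LLAMA_POOLING_TYPE_UNSPECIFIED (the default when the flag is absent)
    // defers to the pooling type recorded in the model's GGUF metadata.
    static llama_context * make_embedding_ctx(llama_model * model, enum llama_pooling_type pooling) {
        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;    // extract embeddings, not just logits
        cparams.pooling_type = pooling; // LLAMA_POOLING_TYPE_{NONE,MEAN,CLS}
        return llama_new_context_with_model(model, cparams);
    }

With `--pooling none`, `llama_get_embeddings_seq()` returns NULL and `send_embedding()` above falls back to per-token embeddings.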
        else if (arg == "--threads" || arg == "-t")
        {
            if (++i >= argc)
            {
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
        }
        else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
        {