@@ -2005,6 +2005,11 @@ struct server_context {
20052005 int32_t n_batch = llama_n_batch (ctx);
20062006 int32_t n_ubatch = llama_n_ubatch (ctx);
20072007
2008+ // track if this is an embedding or non-embedding batch
2009+ // if we've added sampled tokens above, we are in non-embedding mode
2010+ // -1: none, 0: non-embedding, 1: embedding
2011+ int32_t batch_type = batch.n_tokens > 0 ? 0 : -1 ;
2012+
20082013 // next, batch any pending prompts without exceeding n_batch
20092014 if (params.cont_batching || batch.n_tokens == 0 ) {
20102015 for (auto & slot : slots) {
@@ -2175,6 +2180,14 @@ struct server_context {
21752180 }
21762181 }
21772182
2183+ // check that we are in the right batch_type, if not defer the slot
2184+ bool slot_type = slot.embedding ? 1 : 0 ;
2185+ if (batch_type == -1 ) {
2186+ batch_type = slot_type;
2187+ } else if (batch_type != slot_type) {
2188+ continue ;
2189+ }
2190+
21782191 // keep only the common part
21792192 int p0 = (int ) system_tokens.size () + slot.n_past ;
21802193 if (!llama_kv_cache_seq_rm (ctx, slot.id + 1 , p0, -1 )) {
@@ -2276,6 +2289,9 @@ struct server_context {
22762289 {" n_tokens" , batch.n_tokens },
22772290 });
22782291
2292+ // make sure we're in the right embedding mode
2293+ llama_set_embeddings (ctx, batch_type == 1 );
2294+
22792295 // process the created batch of tokens
22802296 for (int32_t i = 0 ; i < batch.n_tokens ; i += n_batch) {
22812297 const int32_t n_tokens = std::min (n_batch, batch.n_tokens - i);
@@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
29903006 };
29913007
29923008 const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3009+ if (ctx_server.params .embedding ) {
3010+ res_error (res, format_error_response (" This server does not support completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3011+ return ;
3012+ }
3013+
29933014 res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
29943015
29953016 json data = json::parse (req.body );
@@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
30853106 };
30863107
30873108 const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
3109+ if (ctx_server.params .embedding ) {
3110+ res_error (res, format_error_response (" This server does not support chat completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3111+ return ;
3112+ }
3113+
30883114 res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
30893115 json data = oaicompat_completion_params_parse (ctx_server.model , json::parse (req.body ), params.chat_template );
30903116
@@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
31573183 };
31583184
31593185 const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3186+ if (ctx_server.params .embedding ) {
3187+ res_error (res, format_error_response (" This server does not support infill. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3188+ return ;
3189+ }
3190+
31603191 res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
31613192
31623193 json data = json::parse (req.body );
@@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
32433274 return res.set_content (data.dump (), " application/json; charset=utf-8" );
32443275 };
32453276
3246- const auto handle_embeddings = [&params, & ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3277+ const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
32473278 res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3248- if (!params.embedding ) {
3249- res.status = 501 ;
3250- res.set_content (" This server does not support embeddings. Start it with `--embeddings`" , " text/plain; charset=utf-8" );
3251- return ;
3252- }
32533279
32543280 const json body = json::parse (req.body );
32553281 bool is_openai = false ;
0 commit comments