@@ -37,6 +37,7 @@ struct EmbeddingParams
3737 bool parse_special;
3838 ctl::string_view prompt;
3939 ctl::string content;
40+ ctl::string model;
4041};
4142
4243void
@@ -78,10 +79,29 @@ Client::get_embedding_params(EmbeddingParams* params)
7879{
7980 params->add_special = atob (or_empty (param (" add_special" )), true );
8081 params->parse_special = atob (or_empty (param (" parse_special" )), false );
82+
83+ // try obtaining prompt (or its aliases) from request-uri
8184 ctl::optional<ctl::string_view> prompt = param (" content" );
85+ if (!prompt.has_value ()) {
86+ ctl::optional<ctl::string_view> prompt2 = param (" prompt" );
87+ if (prompt2.has_value ()) {
88+ prompt = ctl::move (prompt2);
89+ } else {
90+ ctl::optional<ctl::string_view> prompt3 = param (" input" );
91+ if (prompt3.has_value ()) {
92+ prompt = ctl::move (prompt3);
93+ }
94+ }
95+ }
96+
8297 if (prompt.has_value ()) {
98+ // [simple mode] if the prompt was supplied in the request-uri
99+ // then we don't bother looking for a json body.
83100 params->prompt = prompt.value ();
84101 } else if (HasHeader (kHttpContentType )) {
102+ // [standard mode] if the prompt wasn't specified as a
103+ // request-uri parameter, then it must be in the
104+ // http message body.
85105 if (IsMimeType (HeaderData (kHttpContentType ),
86106 HeaderLength (kHttpContentType ),
87107 " text/plain" )) {
@@ -94,14 +114,21 @@ Client::get_embedding_params(EmbeddingParams* params)
94114 return send_error (400 , Json::StatusToString (json.first ));
95115 if (!json.second .isObject ())
96116 return send_error (400 , " JSON body must be an object" );
97- if (!json.second [" content" ].isString ())
98- return send_error (400 , " JSON missing \" content\" key" );
99- params->content = ctl::move (json.second [" content" ].getString ());
117+ if (json.second [" content" ].isString ())
118+ params->content = ctl::move (json.second [" content" ].getString ());
119+ else if (json.second [" prompt" ].isString ())
120+ params->content = ctl::move (json.second [" prompt" ].getString ());
121+ else if (json.second [" input" ].isString ())
122+ params->content = ctl::move (json.second [" input" ].getString ());
123+ else
124+ return send_error (400 , " JSON missing content/prompt/input key" );
100125 params->prompt = params->content ;
101126 if (json.second [" add_special" ].isBool ())
102127 params->add_special = json.second [" add_special" ].getBool ();
103128 if (json.second [" parse_special" ].isBool ())
104129 params->parse_special = json.second [" parse_special" ].getBool ();
130+ if (json.second [" model" ].isString ())
131+ params->model = ctl::move (json.second [" model" ].getString ());
105132 } else {
106133 return send_error (501 , " Content Type Not Implemented" );
107134 }
@@ -207,21 +234,68 @@ Client::embedding()
207234 embd, embeddings->data () + batch->seq_id [i][0 ] * n_embd, n_embd);
208235 }
209236
237+ // determine how output json should look
238+ bool in_openai_mode = path () == " /v1/embeddings" ;
239+
210240 // serialize tokens to json
211241 char * p = obuf.p ;
212242 p = stpcpy (p, " {\n " );
213- p = stpcpy (p, " \" add_special\" : " );
214- p = encode_bool (p, params->add_special );
215- p = stpcpy (p, " ,\n " );
216- p = stpcpy (p, " \" parse_special\" : " );
217- p = encode_bool (p, params->parse_special );
218- p = stpcpy (p, " ,\n " );
219- p = stpcpy (p, " \" tokens_provided\" : " );
220- p = encode_json (p, toks->size ());
221- p = stpcpy (p, " ,\n " );
222- p = stpcpy (p, " \" tokens_used\" : " );
223- p = encode_json (p, count);
224- p = stpcpy (p, " ,\n " );
243+
244+ // Here's what an OpenAI /v1/embeddings response looks like:
245+ //
246+ // {
247+ // "object": "list",
248+ // "data": [
249+ // {
250+ // "object": "embedding",
251+ // "index": 0,
252+ // "embedding": [
253+ // -0.006929283495992422,
254+ // -0.005336422007530928,
255+ // ... (omitted for spacing)
256+ // -4.547132266452536e-05,
257+ // -0.024047505110502243
258+ // ]
259+ // }
260+ // ],
261+ // "model": "text-embedding-3-small",
262+ // "usage": {
263+ // "prompt_tokens": 5,
264+ // "total_tokens": 5
265+ // }
266+ // }
267+ //
268+
269+ if (in_openai_mode) {
270+ p = stpcpy (p, " \" object\" : \" list\" ,\n " );
271+ p = stpcpy (p, " \" model\" : " );
272+ p = encode_json (p, params->model );
273+ p = stpcpy (p, " ,\n " );
274+ p = stpcpy (p, " \" usage\" : {\n " );
275+ p = stpcpy (p, " \" prompt_tokens\" : " );
276+ p = encode_json (p, count);
277+ p = stpcpy (p, " ,\n " );
278+ p = stpcpy (p, " \" total_tokens\" : " );
279+ p = encode_json (p, toks->size ());
280+ p = stpcpy (p, " \n },\n " );
281+ p = stpcpy (p, " \" data\" : [{\n " );
282+ p = stpcpy (p, " \" object\" : \" embedding\" ,\n " );
283+ p = stpcpy (p, " \" index\" : 0,\n " );
284+ } else {
285+ p = stpcpy (p, " \" add_special\" : " );
286+ p = encode_bool (p, params->add_special );
287+ p = stpcpy (p, " ,\n " );
288+ p = stpcpy (p, " \" parse_special\" : " );
289+ p = encode_bool (p, params->parse_special );
290+ p = stpcpy (p, " ,\n " );
291+ p = stpcpy (p, " \" tokens_provided\" : " );
292+ p = encode_json (p, toks->size ());
293+ p = stpcpy (p, " ,\n " );
294+ p = stpcpy (p, " \" tokens_used\" : " );
295+ p = encode_json (p, count);
296+ p = stpcpy (p, " ,\n " );
297+ }
298+
225299 p = stpcpy (p, " \" embedding\" : [" );
226300 for (size_t i = 0 ; i < embeddings->size (); ++i) {
227301 if (i) {
@@ -231,6 +305,8 @@ Client::embedding()
231305 p = encode_json (p, (*embeddings)[i]);
232306 }
233307 p = stpcpy (p, " ]\n " );
308+ if (in_openai_mode)
309+ p = stpcpy (p, " }]\n " );
234310 p = stpcpy (p, " }\n " );
235311 ctl::string_view content (obuf.p , p - obuf.p );
236312
0 commit comments