from fastapi import APIRouter, Depends, Request, Body, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
-from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
-from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
+from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
+from app.domain import (
+    Tags,
+    OpenAIChatRequest,
+    OpenAIChatResponse,
+    OpenAIEmbeddingsRequest,
+    OpenAIEmbeddingsResponse,
+    PromptMessage,
+    PromptRole,
+)
from app.model_services.base import AbstractModelService
from app.utils import get_settings, get_prompt_from_messages
from app.api.utils import get_rate_limiter

PATH_GENERATE = "/generate"
PATH_GENERATE_ASYNC = "/stream/generate"
PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
+PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"

router = APIRouter()
config = get_settings()
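Note: this diff doesn't include the definitions of the new OpenAIEmbeddingsRequest and OpenAIEmbeddingsResponse models. A minimal sketch, inferred only from how the handler below uses them (the real models in app/domain.py may differ, e.g. carry extra fields such as usage):

# Hypothetical sketch only; field names are inferred from the embed_texts
# handler in this diff, not taken from app/domain.py.
from typing import Any, Dict, List, Union
from pydantic import BaseModel

class OpenAIEmbeddingsRequest(BaseModel):
    model: str                    # requested model name (the handler reports the server's name anyway)
    input: Union[str, List[str]]  # a single text or a batch of texts

class OpenAIEmbeddingsResponse(BaseModel):
    object: str                 # always "list", matching OpenAI's embeddings schema
    data: List[Dict[str, Any]]  # [{"object": "embedding", "embedding": [...], "index": i}, ...]
    model: str                  # the model that produced the embeddings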
@@ -134,7 +143,7 @@ async def generate_text_stream(

@router.post(
    PATH_OPENAI_COMPLETIONS,
-    tags=[Tags.Generative.name],
+    tags=[Tags.OpenAICompatible.name],
    response_model=None,
    dependencies=[Depends(cms_globals.props.current_active_user)],
    description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
@@ -162,6 +171,7 @@ def generate_chat_completions(
    """

    messages = request_data.messages
+    model = model_service.model_name  # request_data.model is ignored; the response always reports the server's model name
    stream = request_data.stream
    max_tokens = request_data.max_tokens
    temperature = request_data.temperature
@@ -224,7 +234,7 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
        id=tracking_id,
        object="chat.completion",
        created=int(time.time()),
-        model=model_service.model_name,
+        model=model,
        choices=[
            {
                "index": 0,
@@ -239,14 +249,100 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
    return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})


+@router.post(
+    PATH_OPENAI_EMBEDDINGS,
+    tags=[Tags.OpenAICompatible.name],
+    response_model=None,
+    dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
+)
+def embed_texts(
+    request: Request,
+    request_data: Annotated[OpenAIEmbeddingsRequest, Body(
+        description="Text(s) to be embedded", media_type="application/json"
+    )],
+    tracking_id: Union[str, None] = Depends(validate_tracking_id),
+    model_service: AbstractModelService = Depends(cms_globals.model_service_dep),
+) -> JSONResponse:
+    """
+    Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.
+
+    Args:
+        request (Request): The request object.
+        request_data (OpenAIEmbeddingsRequest): The request data containing the model and input text(s).
+        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
+        model_service (AbstractModelService): The model service dependency.
+
+    Returns:
+        JSONResponse: A response containing the embeddings of the text(s).
+    """
+    tracking_id = tracking_id or str(uuid.uuid4())
+
+    if not hasattr(model_service, "create_embeddings"):
+        error_response = {
+            "error": {
+                "message": "Model does not support embeddings",
+                "type": "invalid_request_error",
+                "param": "model",
+                "code": "model_not_supported",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_400_BAD_REQUEST,  # an "invalid_request_error" is a client error, so 400 rather than 500
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    input_text = request_data.input
+    model = model_service.model_name  # request_data.model is ignored; the response always reports the server's model name
+
+    # Normalize a single string into a one-element batch.
+    if isinstance(input_text, str):
+        input_texts = [input_text]
+    else:
+        input_texts = input_text
+
+    try:
+        embeddings_data = []
+
+        for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
+            embeddings_data.append({
+                "object": "embedding",
+                "embedding": embedding,
+                "index": i,
+            })
+
+        response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
+
+        return JSONResponse(
+            content=jsonable_encoder(response),
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    except Exception as e:
+        logger.exception("Failed to create embeddings")  # logs the message together with the traceback
+        error_response = {
+            "error": {
+                "message": f"Failed to create embeddings: {str(e)}",
+                "type": "server_error",
+                "code": "internal_error",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+
def _empty_prompt_error() -> Iterable[str]:
    yield "ERROR: No prompt text provided\n"


def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
    cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num)
-    logger.debug(f"Sent prompt tokens usage: {prompt_token_num}")
+    logger.debug("Sent prompt tokens usage: %s", prompt_token_num)
    cms_completion_tokens.labels(handler=handler).observe(completion_token_num)
-    logger.debug(f"Sent completion tokens usage: {completion_token_num}")
+    logger.debug("Sent completion tokens usage: %s", completion_token_num)
    cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num)
-    logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}")
+    logger.debug("Sent total tokens usage: %s", prompt_token_num + completion_token_num)
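For reference, a call to the new embeddings endpoint might look like the sketch below; the host, port, model name, and bearer-token auth are placeholders, not taken from this diff:

# Hypothetical client-side usage of the new /v1/embeddings route.
import requests  # third-party HTTP client, used here only for illustration

resp = requests.post(
    "http://localhost:8000/v1/embeddings",  # assumed local deployment
    json={"model": "any-name", "input": ["first text", "second text"]},  # the handler reports the server's model name regardless
    headers={"Authorization": "Bearer <token>"},  # auth scheme is an assumption
)
resp.raise_for_status()
payload = resp.json()
# The response mirrors OpenAI's schema: {"object": "list", "data": [...], "model": "..."}
for item in payload["data"]:
    print(item["index"], len(item["embedding"]))
print("tracking id:", resp.headers.get("x-cms-tracking-id"))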