@@ -1,24 +1,16 @@
-import json
 import logging
-import time
-import uuid
 import app.api.globals as cms_globals
 
-from typing import Union, Iterable, AsyncGenerator
 from typing_extensions import Annotated
 from fastapi import APIRouter, Depends, Request, Body, Query
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
-from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
-from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
+from fastapi.responses import PlainTextResponse, StreamingResponse
+from app.domain import Tags
 from app.model_services.base import AbstractModelService
-from app.utils import get_settings, get_prompt_from_messages
+from app.utils import get_settings
 from app.api.utils import get_rate_limiter
-from app.api.dependencies import validate_tracking_id
 
 PATH_GENERATE = "/generate"
 PATH_GENERATE_ASYNC = "/stream/generate"
-PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
 
 router = APIRouter()
 config = get_settings()
@@ -39,8 +31,6 @@ def generate_text(
     request: Request,
     prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
     max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
-    temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> PlainTextResponse:
     """
@@ -50,27 +40,13 @@ def generate_text(
         request (Request): The request object.
         prompt (str): The prompt to be sent to the model.
         max_tokens (int): The maximum number of tokens to generate.
-        temperature (float): The temperature of the generated text.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
         model_service (AbstractModelService): The model service dependency.
 
     Returns:
         PlainTextResponse: A response containing the generated text.
     """
 
-    tracking_id = tracking_id or str(uuid.uuid4())
-    if prompt:
-        return PlainTextResponse(
-            model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_200_OK,
-        )
-    else:
-        return PlainTextResponse(
-            _empty_prompt_error(),
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_400_BAD_REQUEST,
-        )
+    return PlainTextResponse(model_service.generate(prompt, max_tokens=max_tokens))
 
 
 @router.post(
@@ -84,8 +60,6 @@ async def generate_text_stream(
     request: Request,
     prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
     max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
-    temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> StreamingResponse:
     """
@@ -95,121 +69,13 @@ async def generate_text_stream(
         request (Request): The request object.
         prompt (str): The prompt to be sent to the model.
         max_tokens (int): The maximum number of tokens to generate.
-        temperature (float): The temperature of the generated text.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
         model_service (AbstractModelService): The model service dependency.
 
     Returns:
         StreamingResponse: A streaming response containing the text generated in near real-time.
     """
 
-    tracking_id = tracking_id or str(uuid.uuid4())
-    if prompt:
-        return StreamingResponse(
-            model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
-            media_type="text/event-stream",
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_200_OK,
-        )
-    else:
-        return StreamingResponse(
-            _empty_prompt_error(),
-            media_type="text/event-stream",
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_400_BAD_REQUEST,
-        )
-
-
-@router.post(
-    "/v1/chat/completions",
-    tags=[Tags.Generative.name],
-    response_model=None,
-    dependencies=[Depends(cms_globals.props.current_active_user)],
-    description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
-)
-def generate_chat_completions(
-    request: Request,
-    request_data: Annotated[OpenAIChatRequest, Body(
-        description="OpenAI-like completion request", media_type="application/json"
-    )],
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
-    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
-) -> Union[StreamingResponse, JSONResponse]:
-    """
-    Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
-
-    Args:
-        request (Request): The request object.
-        request_data (OpenAIChatRequest): The request data containing model, messages, and stream.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
-        model_service (AbstractModelService): The model service dependency.
-
-    Returns:
-        StreamingResponse: A OpenAI-like response containing the text generated in near real-time.
-        JSONResponse: A response containing an error message if the prompt messages are empty.
-    """
-
-    messages = request_data.messages
-    stream = request_data.stream
-    max_tokens = request_data.max_tokens
-    temperature = request_data.temperature
-
-    if not messages:
-        error_response = {
-            "error": {
-                "message": "No prompt messages provided",
-                "type": "invalid_request_error",
-                "param": "messages",
-                "code": "missing_field",
-            }
-        }
-        return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST)
-
-    async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
-        data = {
-            "id": tracking_id or str(uuid.uuid4()),
-            "object": "chat.completion.chunk",
-            "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}],
-        }
-        yield f"data: {json.dumps(data)}\n\n"
-        async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t):
-            data = {
-                "choices": [
-                    {
-                        "delta": {"content": chunk}
-                    }
-                ],
-                "object": "chat.completion.chunk",
-            }
-            yield f"data: {json.dumps(data)}\n\n"
-        yield "data: [DONE]\n\n"
-
-    prompt = get_prompt_from_messages(model_service.tokenizer, messages)  # type: ignore
-    if stream:
-        return StreamingResponse(
-            _stream(prompt, max_tokens, temperature),
-            media_type="text/event-stream"
-        )
-    else:
-        generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
-        completion = OpenAIChatResponse(
-            id=str(uuid.uuid4()),
-            object="chat.completion",
-            created=int(time.time()),
-            model=model_service.model_name,
-            choices=[
-                {
-                    "index": 0,
-                    "message": PromptMessage(
-                        role=PromptRole.ASSISTANT,
-                        content=generated_text,
-                    ),
-                    "finish_reason": "stop",
-                }
-            ]
-        )
-        return JSONResponse(content=jsonable_encoder(completion))
-
-
-def _empty_prompt_error() -> Iterable[str]:
-    yield "ERROR: No prompt text provided\n"
+    return StreamingResponse(
+        model_service.generate_async(prompt, max_tokens=max_tokens),
+        media_type="text/event-stream"
+    )
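
For reference, here is a minimal client-side sketch (not part of this diff) of how the two remaining endpoints might be called after this change. It assumes a locally running service at a made-up base URL and omits any authentication headers the deployment may require.

# Hypothetical usage sketch for the simplified endpoints; the base URL,
# prompt text, and absence of auth headers are assumptions, not part of the change.
import httpx

BASE_URL = "http://localhost:8000"  # assumed deployment address

# POST /generate: the prompt travels as a plain-text body, max_tokens as a query parameter.
response = httpx.post(
    f"{BASE_URL}/generate",
    params={"max_tokens": 128},
    content="Write a short discharge summary.",
    headers={"Content-Type": "text/plain"},
)
print(response.text)  # generated text, returned as plain text

# POST /stream/generate: same inputs, but the reply arrives as a text/event-stream.
with httpx.stream(
    "POST",
    f"{BASE_URL}/stream/generate",
    params={"max_tokens": 128},
    content="Write a short discharge summary.",
    headers={"Content-Type": "text/plain"},
) as stream:
    for chunk in stream.iter_text():
        print(chunk, end="", flush=True)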