 
 from typing import Union, Iterable, AsyncGenerator
 from typing_extensions import Annotated
+from functools import partial
 from fastapi import APIRouter, Depends, Request, Body, Query
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
 from app.utils import get_settings, get_prompt_from_messages
 from app.api.utils import get_rate_limiter
 from app.api.dependencies import validate_tracking_id
+from app.management.prometheus_metrics import cms_prompt_tokens, cms_completion_tokens, cms_total_tokens
 
 PATH_GENERATE = "/generate"
 PATH_GENERATE_ASYNC = "/stream/generate"
@@ -44,7 +46,7 @@ def generate_text(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> PlainTextResponse:
     """
-    Generate text based on the prompt provided.
+    Generates text based on the prompt provided.
 
     Args:
         request (Request): The request object.
@@ -61,7 +63,12 @@ def generate_text(
     tracking_id = tracking_id or str(uuid.uuid4())
     if prompt:
         return PlainTextResponse(
-            model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
+            model_service.generate(
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE),
+            ),
             headers={"x-cms-tracking-id": tracking_id},
             status_code=HTTP_200_OK,
         )
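
The hook wiring here is worth a note: `functools.partial` pre-binds the `handler` label to the endpoint path, so the model service only has to supply the token counts. A minimal sketch of the resulting callable, using only names visible in this diff:

```python
from functools import partial

def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
    ...  # records the counts against the bound handler label

report_tokens = partial(_send_usage_metrics, handler=PATH_GENERATE)

# The model service can later invoke it without knowing which endpoint it serves:
report_tokens(prompt_token_num=12, completion_token_num=87)
```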
@@ -89,7 +96,7 @@ async def generate_text_stream(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> StreamingResponse:
     """
-    Generate a stream of texts in near real-time.
+    Generates a stream of texts in near real-time.
 
     Args:
         request (Request): The request object.
@@ -106,7 +113,12 @@ async def generate_text_stream(
     tracking_id = tracking_id or str(uuid.uuid4())
     if prompt:
         return StreamingResponse(
-            model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
+            model_service.generate_async(
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC),
+            ),
             media_type="text/event-stream",
             headers={"x-cms-tracking-id": tracking_id},
             status_code=HTTP_200_OK,
@@ -121,7 +133,7 @@ async def generate_text_stream(
 
 
 @router.post(
-    "/v1/chat/completions",
+    PATH_OPENAI_COMPLETIONS,
     tags=[Tags.Generative.name],
     response_model=None,
     dependencies=[Depends(cms_globals.props.current_active_user)],
@@ -136,7 +148,7 @@ def generate_chat_completions(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> Union[StreamingResponse, JSONResponse]:
     """
-    Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
+    Generates a chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
 
     Args:
         request (Request): The request object.
@@ -153,6 +165,7 @@ def generate_chat_completions(
     stream = request_data.stream
     max_tokens = request_data.max_tokens
     temperature = request_data.temperature
+    tracking_id = tracking_id or str(uuid.uuid4())
 
     if not messages:
         error_response = {
@@ -163,16 +176,25 @@ def generate_chat_completions(
163176 "code" : "missing_field" ,
164177 }
165178 }
166- return JSONResponse (content = error_response , status_code = HTTP_400_BAD_REQUEST )
179+ return JSONResponse (
180+ content = error_response ,
181+ status_code = HTTP_400_BAD_REQUEST ,
182+ headers = {"x-cms-tracking-id" : tracking_id },
183+ )
167184
168- async def _stream (p : str , mt : int , t : float ) -> AsyncGenerator :
185+ async def _stream (prompt : str , max_tokens : int , temperature : float ) -> AsyncGenerator :
169186 data = {
170- "id" : tracking_id or str ( uuid . uuid4 ()) ,
187+ "id" : tracking_id ,
171188 "object" : "chat.completion.chunk" ,
172189 "choices" : [{"delta" : {"role" : PromptRole .ASSISTANT .value }}],
173190 }
174191 yield f"data: { json .dumps (data )} \n \n "
175- async for chunk in model_service .generate_async (p , max_tokens = mt , temperature = t ):
192+ async for chunk in model_service .generate_async (
193+ prompt ,
194+ max_tokens = max_tokens ,
195+ temperature = temperature ,
196+ report_tokens = partial (_send_usage_metrics , handler = PATH_OPENAI_COMPLETIONS )
197+ ):
176198 data = {
177199 "choices" : [
178200 {
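
The hunk cuts off inside the per-chunk dict here. Judging from the role-announcing chunk above and the OpenAI chunk format this endpoint mimics, each streamed chunk presumably carries the generated text in a `delta`; the exact shape below is an assumption, since it falls outside the visible diff:

```python
# Assumed continuation of the per-chunk payload (not visible in this diff):
data = {
    "choices": [
        {
            "delta": {"content": chunk},  # `chunk` comes from generate_async
        }
    ],
}
yield f"data: {json.dumps(data)}\n\n"
```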
@@ -188,12 +210,18 @@ async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
     if stream:
         return StreamingResponse(
             _stream(prompt, max_tokens, temperature),
-            media_type="text/event-stream"
+            media_type="text/event-stream",
+            headers={"x-cms-tracking-id": tracking_id},
         )
     else:
-        generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
+        generated_text = model_service.generate(
+            prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS),
+        )
         completion = OpenAIChatResponse(
-            id=str(uuid.uuid4()),
+            id=tracking_id,
             object="chat.completion",
             created=int(time.time()),
             model=model_service.model_name,
@@ -206,10 +234,19 @@ async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
                     ),
                     "finish_reason": "stop",
                 }
-            ]
+            ],
         )
-        return JSONResponse(content=jsonable_encoder(completion))
+        return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})
 
 
 def _empty_prompt_error() -> Iterable[str]:
     yield "ERROR: No prompt text provided\n"
+
+
+def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
+    cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num)
+    logger.debug(f"Sent prompt tokens usage: {prompt_token_num}")
+    cms_completion_tokens.labels(handler=handler).observe(completion_token_num)
+    logger.debug(f"Sent completion tokens usage: {completion_token_num}")
+    cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num)
+    logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}")
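
The metric objects imported from `app.management.prometheus_metrics` are not shown in this diff. Since `_send_usage_metrics` calls `.labels(handler=...).observe(...)` on them, they are presumably Histograms (or Summaries); a minimal sketch of what that module might contain, using the standard `prometheus_client` API (names copied from the import, help strings assumed):

```python
# Hypothetical app/management/prometheus_metrics.py -- not part of this diff.
from prometheus_client import Histogram

cms_prompt_tokens = Histogram(
    "cms_prompt_tokens",
    "Prompt tokens consumed per generation request",
    labelnames=["handler"],
)
cms_completion_tokens = Histogram(
    "cms_completion_tokens",
    "Completion tokens produced per generation request",
    labelnames=["handler"],
)
cms_total_tokens = Histogram(
    "cms_total_tokens",
    "Total tokens (prompt + completion) per generation request",
    labelnames=["handler"],
)
```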
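For this wiring to work, `AbstractModelService.generate` and `generate_async` must accept the new `report_tokens` keyword and invoke it once token counts are known. That contract sits outside this diff; a sketch under that assumption:

```python
# Assumed model-service contract -- the real AbstractModelService is not shown here.
from typing import Callable, Optional, Tuple

# Invoked as report_tokens(prompt_token_num=..., completion_token_num=...)
TokenReporter = Callable[..., None]

class AbstractModelService:
    """Simplified stub of the assumed interface."""

    def _run_model(self, prompt: str, max_tokens: int, temperature: float) -> Tuple[str, int, int]:
        # Hypothetical inference call; returns (text, prompt_tokens, completion_tokens).
        raise NotImplementedError

    def generate(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        report_tokens: Optional[TokenReporter] = None,
    ) -> str:
        text, n_prompt, n_completion = self._run_model(prompt, max_tokens, temperature)
        if report_tokens is not None:
            report_tokens(prompt_token_num=n_prompt, completion_token_num=n_completion)
        return text
```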