Commit dc1f076

feat: add metrics for usages of prompt and completion tokens
1 parent d1aefdd

6 files changed: +114 -24 lines changed


app/api/routers/generative.py

Lines changed: 52 additions & 15 deletions
@@ -6,6 +6,7 @@

 from typing import Union, Iterable, AsyncGenerator
 from typing_extensions import Annotated
+from functools import partial
 from fastapi import APIRouter, Depends, Request, Body, Query
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
@@ -15,6 +16,7 @@
 from app.utils import get_settings, get_prompt_from_messages
 from app.api.utils import get_rate_limiter
 from app.api.dependencies import validate_tracking_id
+from app.management.prometheus_metrics import cms_prompt_tokens, cms_completion_tokens, cms_total_tokens

 PATH_GENERATE = "/generate"
 PATH_GENERATE_ASYNC = "/stream/generate"
@@ -44,7 +46,7 @@ def generate_text(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> PlainTextResponse:
     """
-    Generate text based on the prompt provided.
+    Generates text based on the prompt provided.

     Args:
         request (Request): The request object.
@@ -61,7 +63,12 @@ def generate_text(
     tracking_id = tracking_id or str(uuid.uuid4())
     if prompt:
         return PlainTextResponse(
-            model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
+            model_service.generate(
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE),
+            ),
             headers={"x-cms-tracking-id": tracking_id},
             status_code=HTTP_200_OK,
         )
@@ -89,7 +96,7 @@ async def generate_text_stream(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> StreamingResponse:
     """
-    Generate a stream of texts in near real-time.
+    Generates a stream of texts in near real-time.

     Args:
         request (Request): The request object.
@@ -106,7 +113,12 @@ async def generate_text_stream(
     tracking_id = tracking_id or str(uuid.uuid4())
     if prompt:
         return StreamingResponse(
-            model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
+            model_service.generate_async(
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC),
+            ),
             media_type="text/event-stream",
             headers={"x-cms-tracking-id": tracking_id},
             status_code=HTTP_200_OK,
@@ -121,7 +133,7 @@ async def generate_text_stream(


 @router.post(
-    "/v1/chat/completions",
+    PATH_OPENAI_COMPLETIONS,
     tags=[Tags.Generative.name],
     response_model=None,
     dependencies=[Depends(cms_globals.props.current_active_user)],
@@ -136,7 +148,7 @@ def generate_chat_completions(
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> Union[StreamingResponse, JSONResponse]:
     """
-    Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
+    Generates chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.

     Args:
         request (Request): The request object.
@@ -153,6 +165,7 @@ def generate_chat_completions(
     stream = request_data.stream
     max_tokens = request_data.max_tokens
     temperature = request_data.temperature
+    tracking_id = tracking_id or str(uuid.uuid4())

     if not messages:
         error_response = {
@@ -163,16 +176,25 @@ def generate_chat_completions(
                 "code": "missing_field",
             }
         }
-        return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST)
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_400_BAD_REQUEST,
+            headers={"x-cms-tracking-id": tracking_id},
+        )

-    async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
+    async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator:
         data = {
-            "id": tracking_id or str(uuid.uuid4()),
+            "id": tracking_id,
             "object": "chat.completion.chunk",
             "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}],
         }
         yield f"data: {json.dumps(data)}\n\n"
-        async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t):
+        async for chunk in model_service.generate_async(
+            prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS)
+        ):
             data = {
                 "choices": [
                     {
@@ -188,12 +210,18 @@ async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
     if stream:
         return StreamingResponse(
             _stream(prompt, max_tokens, temperature),
-            media_type="text/event-stream"
+            media_type="text/event-stream",
+            headers={"x-cms-tracking-id": tracking_id},
         )
     else:
-        generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
+        generated_text = model_service.generate(
+            prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS),
+        )
         completion = OpenAIChatResponse(
-            id=str(uuid.uuid4()),
+            id=tracking_id,
             object="chat.completion",
             created=int(time.time()),
             model=model_service.model_name,
@@ -206,10 +234,19 @@ async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
                    ),
                    "finish_reason": "stop",
                }
-            ]
+            ],
        )
-        return JSONResponse(content=jsonable_encoder(completion))
+        return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})


 def _empty_prompt_error() -> Iterable[str]:
     yield "ERROR: No prompt text provided\n"
+
+
+def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
+    cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num)
+    logger.debug(f"Sent prompt tokens usage: {prompt_token_num}")
+    cms_completion_tokens.labels(handler=handler).observe(completion_token_num)
+    logger.debug(f"Sent completion tokens usage: {completion_token_num}")
+    cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num)
+    logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}")

app/domain.py

Lines changed: 5 additions & 5 deletions
@@ -189,8 +189,8 @@ class OpenAIChatRequest(BaseModel):


 class OpenAIChatResponse(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: str
-    choices: List
+    id: str = Field(..., description="The unique identifier for the chat completion request")
+    object: str = Field(..., description="The type of the response")
+    created: int = Field(..., description="The timestamp when the completion was generated")
+    model: str = Field(..., description="The name of the model used for generating the completion")
+    choices: List = Field(..., description="The generated messages and their metadata")
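
Swapping the bare annotations for Field(..., description=...) keeps every field required (the Ellipsis default) while surfacing the descriptions in the generated JSON/OpenAPI schema. A small self-contained check of that behaviour, assuming Pydantic v2 is in use:

from typing import List

from pydantic import BaseModel, Field


class OpenAIChatResponse(BaseModel):
    id: str = Field(..., description="The unique identifier for the chat completion request")
    object: str = Field(..., description="The type of the response")
    created: int = Field(..., description="The timestamp when the completion was generated")
    model: str = Field(..., description="The name of the model used for generating the completion")
    choices: List = Field(..., description="The generated messages and their metadata")


schema = OpenAIChatResponse.model_json_schema()
print(schema["required"])                         # all five fields remain required
print(schema["properties"]["id"]["description"])  # descriptions now appear in the schema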

app/management/prometheus_metrics.py

Lines changed: 21 additions & 0 deletions
@@ -34,3 +34,24 @@
     "Number of bulk-processed documents",
     ["handler"],
 )
+
+# The histogram metric to track the number of tokens in the messages of the input prompt
+cms_prompt_tokens = Histogram(
+    "cms_prompt_tokens",
+    "Number of tokens in the messages of the input prompt",
+    ["handler"],
+)
+
+# The histogram metric to track the number of tokens in the generated assistant reply
+cms_completion_tokens = Histogram(
+    "cms_completion_tokens",
+    "Number of tokens in the generated assistant reply",
+    ["handler"],
+)
+
+# The histogram metric to track the total number of tokens used in the prompt and the completion
+cms_total_tokens = Histogram(
+    "cms_total_tokens",
+    "Number of tokens used in the prompt and the completion",
+    ["handler"],
+)
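
Each histogram carries a single handler label, so token usage is tracked per endpoint. A standalone sketch with prometheus_client (using a private registry rather than the app's) of what one observation produces:

from prometheus_client import CollectorRegistry, Histogram, generate_latest

registry = CollectorRegistry()
cms_prompt_tokens = Histogram(
    "cms_prompt_tokens",
    "Number of tokens in the messages of the input prompt",
    ["handler"],
    registry=registry,
)

# One observation per request, keyed by the endpoint that served it.
cms_prompt_tokens.labels(handler="/generate").observe(42)
cms_prompt_tokens.labels(handler="/v1/chat/completions").observe(128)

# The exposition output contains _bucket, _sum and _count series per handler label.
print(generate_latest(registry).decode())

Because histograms expose _sum and _count per label, an average-tokens-per-request view can be derived from their ratio over a time window. Note that the commit keeps prometheus_client's default buckets, whose largest finite bound is 10.0, so most token counts will fall into the +Inf bucket; supplying custom buckets sized to typical context lengths may be worth considering if per-bucket resolution matters.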

app/model_services/huggingface_llm_model.py

Lines changed: 22 additions & 3 deletions
@@ -2,7 +2,7 @@
 import logging
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Tuple, Any, AsyncIterable
+from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -198,6 +198,7 @@ def generate(
         prompt: str,
         max_tokens: int = 512,
         temperature: float = 0.7,
+        report_tokens: Optional[Callable[[str], None]] = None,
         **kwargs: Any
     ) -> str:
         """
@@ -207,6 +208,7 @@ def generate(
             prompt (str): The prompt for the text generation
             max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
             temperature (float): The temperature for the text generation. Defaults to 0.7.
+            report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None.
             **kwargs (Any): Additional keyword arguments to be passed to this method.

         Returns:
@@ -230,17 +232,22 @@ def generate(

         outputs = self.model.generate(**generation_kwargs)
         generated_text = self.tokenizer.decode(outputs[0], skip_prompt=True, skip_special_tokens=True)
-
-
         logger.debug("Response generation completed")

+        if report_tokens:
+            report_tokens(
+                prompt_token_num=inputs.input_ids.shape[-1],  # type: ignore
+                completion_token_num=outputs[0].shape[-1],  # type: ignore
+            )
+
         return generated_text

     async def generate_async(
         self,
         prompt: str,
         max_tokens: int = 512,
         temperature: float = 0.7,
+        report_tokens: Optional[Callable[[str], None]] = None,
         **kwargs: Any
     ) -> AsyncIterable:
         """
@@ -250,6 +257,7 @@ async def generate_async(
             prompt (str): The prompt for the text generation.
             max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
             temperature (float): The temperature for the text generation. Defaults to 0.7.
+            report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None.
             **kwargs (Any): Additional keyword arguments to be passed to the model loader.

         Returns:
@@ -279,9 +287,20 @@ async def generate_async(

         try:
             _ = self._text_generator.submit(self.model.generate, **generation_kwargs)
+            output = ""
             for content in streamer:
                 yield content
+                output += content
                 await asyncio.sleep(0.01)
+            if report_tokens:
+                report_tokens(
+                    prompt_token_num=inputs.input_ids.shape[-1],  # type: ignore
+                    completion_token_num=self.tokenizer(  # type: ignore
+                        output,
+                        add_special_tokens=False,
+                        return_tensors="pt"
+                    ).input_ids.shape[-1],
+                )
         except Exception as e:
             logger.error("An error occurred while generating the response")
             logger.exception(e)
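
In the synchronous path the prompt length is read from inputs.input_ids and the completion length from the generated output tensor (for causal LMs, model.generate typically returns the prompt tokens followed by the new ones, so that count may include the prompt as well). The streaming path cannot know the completion length up front, so it accumulates the streamed chunks and re-tokenizes the concatenated output once the stream is exhausted. A self-contained sketch of that accumulate-then-report pattern, with a whitespace split standing in for the Hugging Face tokenizer and hard-coded chunks standing in for the TextIteratorStreamer:

import asyncio
from typing import AsyncIterable, Callable, Optional


def _count_tokens(text: str) -> int:
    # Stand-in for tokenizer(text, ...).input_ids.shape[-1]
    return len(text.split())


async def generate_async(prompt: str, report_tokens: Optional[Callable[..., None]] = None) -> AsyncIterable[str]:
    output = ""
    for content in ["All ", "good ", "here."]:  # stand-in for the streamer
        yield content
        output += content
        await asyncio.sleep(0.01)
    if report_tokens:
        # Reported only once the stream is exhausted, mirroring the diff above.
        report_tokens(prompt_token_num=_count_tokens(prompt), completion_token_num=_count_tokens(output))


def _show(prompt_token_num: int, completion_token_num: int) -> None:
    print(f"prompt={prompt_token_num}, completion={completion_token_num}")


async def main() -> None:
    async for chunk in generate_async("Alright?", report_tokens=_show):
        pass


asyncio.run(main())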

app/utils.py

Lines changed: 2 additions & 1 deletion
@@ -694,7 +694,7 @@ def dump_pydantic_object_to_dict(model: BaseModel) -> Dict:
     """

     if hasattr(model, "model_dump"):
-        return model.model_dump()  # type: ignore
+        return model.model_dump(mode="json")  # type: ignore
     elif hasattr(model, "dict"):
         return model.dict()  # type: ignore
     else:
@@ -835,3 +835,4 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
     "25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.',
     "55540447": "linkage concept"
 }
+
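
The only functional change in dump_pydantic_object_to_dict is model_dump(mode="json"), which converts values a plain model_dump() would leave as Python objects (UUIDs, datetimes, enums) into JSON-compatible primitives. A quick illustration, assuming Pydantic v2:

import uuid
from datetime import datetime, timezone

from pydantic import BaseModel


class Record(BaseModel):
    id: uuid.UUID
    created: datetime


record = Record(id=uuid.uuid4(), created=datetime.now(timezone.utc))
print(record.model_dump())             # {'id': UUID('...'), 'created': datetime.datetime(...)}
print(record.model_dump(mode="json"))  # {'id': '...', 'created': '...T...+00:00'}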

tests/app/model_services/test_huggingface_llm_model.py

Lines changed: 12 additions & 0 deletions
@@ -46,6 +46,7 @@ def test_generate(huggingface_llm_model):
     huggingface_llm_model.init_model()
     huggingface_llm_model.model = MagicMock()
     huggingface_llm_model.tokenizer = MagicMock()
+    mock_send_metrics = MagicMock()
     inputs = MagicMock()
     inputs.input_ids = MagicMock(shape=[1, 2])
     inputs.attention_mask = MagicMock()
@@ -58,6 +59,7 @@ def test_generate(huggingface_llm_model):
         prompt="Alright?",
         max_tokens=128,
         temperature=0.5,
+        report_tokens=mock_send_metrics
     )

     huggingface_llm_model.tokenizer.assert_called_once_with(
@@ -78,13 +80,18 @@ def test_generate(huggingface_llm_model):
         skip_prompt=True,
         skip_special_tokens=True,
     )
+    mock_send_metrics.assert_called_once_with(
+        prompt_token_num=2,
+        completion_token_num=2,
+    )
     assert result == "Yeah."


 async def test_generate_async(huggingface_llm_model):
     huggingface_llm_model.init_model()
     huggingface_llm_model.model = MagicMock()
     huggingface_llm_model.tokenizer = MagicMock()
+    mock_send_metrics = MagicMock()
     inputs = MagicMock()
     inputs.input_ids = MagicMock(shape=[1, 2])
     inputs.attention_mask = MagicMock()
@@ -97,6 +104,7 @@ async def test_generate_async(huggingface_llm_model):
         prompt="Alright?",
         max_tokens=128,
         temperature=0.5,
+        report_tokens=mock_send_metrics
     )

     huggingface_llm_model.tokenizer.assert_called_once_with(
@@ -117,4 +125,8 @@ async def test_generate_async(huggingface_llm_model):
         skip_prompt=True,
         skip_special_tokens=True,
     )
+    mock_send_metrics.assert_called_once_with(
+        prompt_token_num=2,
+        completion_token_num=2,
+    )
     assert result == "Yeah."
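
Both new assertions expect prompt_token_num=2 and completion_token_num=2 because the mocked tensors in these tests carry shape=[1, 2] and the service only reads the last dimension of that shape (the prompt side of that setup is visible in the hunks above; the output side sits in test lines outside the diff). A tiny standalone illustration of the MagicMock trick:

from unittest.mock import MagicMock

inputs = MagicMock()
inputs.input_ids = MagicMock(shape=[1, 2])

# Attribute access returns the configured list, so indexing works like a real tensor shape.
print(inputs.input_ids.shape[-1])  # 2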
