
Commit 8c548e4

TGIS metrics (#18)
This PR implements a subset of the metrics from the TGIS image. I tried to make sure that everything from our current ops dashboard is supported. These are:

- tgi_tokenize_request_tokens
- tgi_tokenize_request_input_count
- tgi_request_input_count
- tgi_request_failure
- tgi_request_queue_duration
- tgi_queue_size
- tgi_batch_current_size
- tgi_batch_inference_duration
- tgi_request_input_length
- tgi_request_generated_tokens

---------

Signed-off-by: Joe Runde <[email protected]>
1 parent 1613074 commit 8c548e4

File tree

3 files changed (+209 / -60 lines changed)
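Only two of the three changed files are reproduced below; the third is presumably the new vllm/tgis_utils/metrics.py that grpc_server.py imports FailureReasonLabel, ServiceMetrics and TGISStatLogger from. As orientation for the diffs that follow, here is a minimal, non-authoritative sketch of what a ServiceMetrics object like the one used below could look like, built with prometheus_client. The metric names come from the commit message; the labels, bucket choices and method bodies are assumptions.

# Hypothetical sketch only -- the real vllm/tgis_utils/metrics.py is not shown on this page.
from enum import Enum

from prometheus_client import Counter, Histogram


class FailureReasonLabel(str, Enum):
    VALIDATION = "validation"  # request failed TGIS or vLLM parameter validation
    GENERATE = "generate"      # a generate RPC failed mid-flight
    OOM = "oom"                # torch.cuda.OutOfMemoryError
    UNKNOWN = "unknown"


class ServiceMetrics:
    def __init__(self):
        self.tgi_request_input_count = Counter(
            "tgi_request_input_count", "Total inputs received by generate RPCs")
        self.tgi_request_failure = Counter(
            "tgi_request_failure", "Count of failed requests, by reason", ["err"])
        self.tgi_request_queue_duration = Histogram(
            "tgi_request_queue_duration",
            "Time each request spent waiting in the engine queue (s)")

    def count_generate_request(self, num_requests: int = 1) -> None:
        self.tgi_request_input_count.inc(num_requests)

    def count_request_failure(self, reason: FailureReasonLabel) -> None:
        self.tgi_request_failure.labels(err=reason.value).inc()

    def observe_queue_time(self, engine_output) -> None:
        # RequestOutput.metrics.time_in_queue is the same field logs.py reads below
        self.tgi_request_queue_duration.observe(engine_output.metrics.time_in_queue)

The tokenization-side observers (observe_tokenization_request / observe_tokenization_response) and the TGISStatLogger wrapper are sketched further down, next to the hunks that use them.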


vllm/entrypoints/grpc/grpc_server.py

Lines changed: 61 additions & 53 deletions
@@ -1,5 +1,4 @@
 import argparse
-import dataclasses
 import inspect
 import time
 import uuid
@@ -37,19 +36,12 @@
 from vllm.tgis_utils import logs
 from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
                                                TypicalLogitsWarperWrapper)
+from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics,
+                                     TGISStatLogger)
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
 logger = init_logger(__name__)
-
-@dataclasses.dataclass
-class Times:
-    """Container tracking times (in seconds) when requests start and finish """
-    # When control enters Generate or GenerateStream
-    request_start: float
-    # When the request is sent to the vLLM engine
-    engine_start: float = 0
-    # When the stream from the vLLM engine closes
-    end: float = 0
+service_metrics = ServiceMetrics()
 
 
 def with_default(value: Any, default: Any) -> Any:
@@ -63,7 +55,13 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
         context = kwargs.get("context", None) or args[-1]
         logger.exception(f"{func.__name__} caused GPU OOM error")
+        service_metrics.count_request_failure(FailureReasonLabel.OOM)
         await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+    else:
+        if "generate" in func.__name__.lower():
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
+        else:
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
     logger.exception(f"{func.__name__} failed")
     raise e
 
@@ -108,10 +106,20 @@ async def _post_init(self):
         self.tokenizer_group = await self.engine.get_tokenizer_group()
         self.tokenizer = await self.engine.get_tokenizer()
 
+        # Swap in the special TGIS stats logger
+        vllm_stat_logger = self.engine.engine.stat_logger
+        tgis_stats_logger = TGISStatLogger(
+            vllm_stat_logger=vllm_stat_logger,
+            max_sequence_len=self.config.max_model_len)
+        # 🌶️🌶️🌶️ sneaky sneak
+        self.engine.engine.stat_logger = tgis_stats_logger
+
+
     @log_rpc_handler_errors
     async def Generate(self, request: BatchedGenerationRequest,
                        context: ServicerContext) -> BatchedGenerationResponse:
         start_time = time.time()
+        service_metrics.count_generate_request(len(request.requests))
         request_id = self.request_id(context)
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, context)
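The hunk above is the only wiring the server needs: it pulls the engine's existing stat logger out of self.engine.engine, wraps it, and puts the wrapper back. The wrapper class itself is not part of the diffs on this page; what follows is a minimal delegation sketch, assuming prometheus_client and guessing at the vLLM Stats field names (which have changed across vLLM versions).

# Hypothetical sketch of the wrapper installed above; only the class name, the constructor
# arguments and the tgi_* metric names are taken from this commit, the rest is assumed.
from prometheus_client import Gauge, Histogram


class TGISStatLogger:
    def __init__(self, vllm_stat_logger, max_sequence_len: int):
        self._vllm_stat_logger = vllm_stat_logger
        self.tgi_queue_size = Gauge(
            "tgi_queue_size", "Requests waiting to be scheduled")
        self.tgi_batch_current_size = Gauge(
            "tgi_batch_current_size", "Requests in the currently running batch")
        # max_sequence_len (== max_model_len above) is a natural upper bound for
        # sizing histogram buckets such as tgi_request_input_length.
        self.tgi_request_input_length = Histogram(
            "tgi_request_input_length", "Prompt length in tokens",
            buckets=[max_sequence_len / 8 * i for i in range(1, 9)])

    def log(self, stats) -> None:
        # Keep the stock vLLM metrics and logging behaviour intact...
        self._vllm_stat_logger.log(stats)
        # ...then export the TGIS-named gauges from the same Stats object.
        # (Field names below are assumptions; check the Stats dataclass of your vLLM version.)
        self.tgi_queue_size.set(stats.num_waiting + stats.num_swapped)
        self.tgi_batch_current_size.set(stats.num_running)

Because only the attribute is swapped, every existing call site inside the LLMEngine keeps logging exactly as before, hence the 🌶️ comment.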
@@ -120,23 +128,19 @@ async def Generate(self, request: BatchedGenerationRequest,
         request_count = len(request.requests)
 
         generators = []
-        timing_infos = []
         max_is_token_limit = [False] * request_count
         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i]\
                 = await self._validate_prompt_and_tokenize(
                     sampling_params, truncate_input_tokens, req.text, context)
-            timing_info = Times(request_start=start_time)
-            timing_infos.append(timing_info)
             generators.append(
-                self.timed_generator(
-                    # prompt is supplied for observability, the text is not
-                    # re-tokenized when `prompt_token_ids` is supplied
-                    self.engine.generate(prompt=req.text,
-                                         sampling_params=sampling_params,
-                                         request_id=f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids),
-                    timing_info))
+                # prompt is supplied for observability, the text is not
+                # re-tokenized when `prompt_token_ids` is supplied
+                self.engine.generate(prompt=req.text,
+                                     sampling_params=sampling_params,
+                                     request_id=f"{request_id}-{i}",
+                                     prompt_token_ids=input_ids),
+            )
 
         # TODO handle cancellation
         result_generator: AsyncIterator[Tuple[
@@ -151,6 +155,7 @@ async def Generate(self, request: BatchedGenerationRequest,
             # await self.engine.abort(f"{request_id}-{i}")
             # return self.create_error_response("Client disconnected")
             responses[i] = res
+            service_metrics.observe_queue_time(res)
 
             if deadline is not None and time.time(
             ) >= deadline and None not in responses:
@@ -173,7 +178,8 @@ async def Generate(self, request: BatchedGenerationRequest,
             kind_log = f"Sub-request {i} from batch of {request_count}"
 
             self._log_unary_response(request=request, response=response,
-                                     times=timing_infos[i], kind_log=kind_log)
+                                     start_time=start_time, engine_response=res,
+                                     kind_log=kind_log)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
@@ -182,7 +188,8 @@ async def Generate(self, request: BatchedGenerationRequest,
     async def GenerateStream(
             self, request: SingleGenerationRequest,
             context: ServicerContext) -> AsyncIterator[GenerationResponse]:
-        timing_info = Times(request_start=time.time())
+        start_time = time.time()
+        service_metrics.count_generate_request()
         request_id = self.request_id(context)
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, context)
@@ -193,16 +200,13 @@ async def GenerateStream(
             sampling_params, truncate_input_tokens, request.request.text,
             context)
 
-        result_generator = self.timed_generator(
-            self.engine.generate(
-                # prompt is supplied for observability, the text is not
-                # re-tokenized when `prompt_token_ids` is supplied
-                prompt=request.request.text,
-                sampling_params=sampling_params,
-                request_id=request_id,
-                prompt_token_ids=input_ids,
-            ),
-            timing_info
+        result_generator = self.engine.generate(
+            # prompt is supplied for observability, the text is not
+            # re-tokenized when `prompt_token_ids` is supplied
+            prompt=request.request.text,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=input_ids,
         )
 
         resp_options = request.params.response
@@ -213,9 +217,12 @@ async def GenerateStream(
         last_token_count = 0
         time_limit_reached = False
         full_output = ""
+        last_engine_response = None
         #TODO handle cancellation
         async for result in result_generator:
+            last_engine_response = result
             if first:
+                service_metrics.observe_queue_time(result)
                 first_response = self._convert_input_details(
                     result, resp_options, sampling_params,
                     GenerationResponse())
@@ -247,7 +254,8 @@ async def GenerateStream(
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
         self._log_streaming_response(request=request, response=first_response,
-                                     times=timing_info)
+                                     start_time=start_time,
+                                     engine_response=last_engine_response)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
@@ -314,6 +322,7 @@ async def _validate_and_convert_params(
         try:
             validate_params(params, self.max_max_new_tokens)
         except ValueError as tgis_validation_error:
+            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
             await context.abort(StatusCode.INVALID_ARGUMENT,
                                 str(tgis_validation_error))
 
@@ -396,6 +405,7 @@ async def _validate_and_convert_params(
         except ValueError as vllm_validation_error:
             # There may be validation cases caught by vLLM that are not covered
             # by the TGIS api validation
+            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
             await context.abort(StatusCode.INVALID_ARGUMENT,
                                 str(vllm_validation_error))
 
@@ -528,36 +538,32 @@ async def _validate_prompt_and_tokenize(
 
     @staticmethod
     def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse, times: Times,
-                            kind_log: str):
+                            response: GenerationResponse,
+                            engine_response: RequestOutput,
+                            start_time: float, kind_log: str):
         logs.log_response(inputs=[r.text for r in request.requests],
                           response=response, params=request.params,
-                          prefix_id=request.prefix_id, times=times,
-                          kind_log=kind_log, method_str="generate",
-                          logger=logger)
+                          prefix_id=request.prefix_id,
+                          engine_response=engine_response,
+                          start_time=start_time, kind_log=kind_log,
+                          method_str="generate", logger=logger)
 
     @staticmethod
     def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse, times: Times):
+                                response: GenerationResponse,
+                                engine_response: RequestOutput,
+                                start_time: float):
         logs.log_response(inputs=[request.request.text], response=response,
                           params=request.params, prefix_id=request.prefix_id,
-                          times=times, kind_log="Streaming response",
+                          engine_response=engine_response,
+                          start_time=start_time, kind_log="Streaming response",
                           method_str="generate_stream", logger=logger)
 
 
-    @staticmethod
-    async def timed_generator(generator: AsyncIterator[RequestOutput],
-                              times: Times) -> AsyncIterator[RequestOutput]:
-        """Injects some timing data around each result generator from the
-        LLMEngine"""
-        times.engine_start = time.time()
-        async for val in generator:
-            yield val
-        times.end = time.time()
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
+        service_metrics.observe_tokenization_request(request)
         #TODO implement these
         if request.return_offsets:
             await context.abort(StatusCode.INVALID_ARGUMENT,
@@ -578,7 +584,9 @@ async def Tokenize(self, request: BatchedTokenizeRequest,
                 tokens=None if not request.return_tokens else
                 self.tokenizer.convert_ids_to_tokens(token_ids)))
 
-        return BatchedTokenizeResponse(responses=responses)
+        response = BatchedTokenizeResponse(responses=responses)
+        service_metrics.observe_tokenization_response(response)
+        return response
 
     @log_rpc_handler_errors
     async def ModelInfo(self, request: ModelInfoRequest,
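The Tokenize handler above feeds both the incoming BatchedTokenizeRequest and the assembled BatchedTokenizeResponse into the metrics object. A small, hedged sketch of what those two observers might compute follows; the token_count field name is assumed from the TGIS generation.proto and is not confirmed by the hunks above.

# Hypothetical continuation of the ServiceMetrics sketch near the top of this page.
from prometheus_client import Counter, Histogram

tgi_tokenize_request_input_count = Counter(
    "tgi_tokenize_request_input_count", "Total inputs received by tokenize RPCs")
tgi_tokenize_request_tokens = Histogram(
    "tgi_tokenize_request_tokens", "Distribution of token counts per tokenized input")


def observe_tokenization_request(request) -> None:
    # one increment per text in the batch
    tgi_tokenize_request_input_count.inc(len(request.requests))


def observe_tokenization_response(response) -> None:
    for resp in response.responses:
        # assumed: each TokenizeResponse carries a token_count for its input
        tgi_tokenize_request_tokens.observe(resp.token_count)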

vllm/tgis_utils/logs.py

Lines changed: 12 additions & 7 deletions
@@ -4,19 +4,23 @@
 
 from google.protobuf import text_format
 
+from vllm import RequestOutput
 from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
                                                      Parameters, StopReason)
 
 
 def log_response(inputs: List[str], params: Parameters, prefix_id: str,
-                 response: GenerationResponse, times, kind_log: str,
-                 method_str: str, logger: logging.Logger):
+                 response: GenerationResponse, engine_response: RequestOutput,
+                 start_time: float, kind_log: str, method_str: str,
+                 logger: logging.Logger):
     """Logs responses similar to how the TGIS server does"""
     # This time contains both request validation and tokenization
-    tokenization_time = times.engine_start - times.request_start
-    llm_engine_time = times.end - times.engine_start
-    time_per_token = _safe_div(llm_engine_time, response.generated_token_count)
-    total_time = times.end - times.request_start
+    tokenization_time = engine_response.metrics.arrival_time - start_time
+    inference_time = (engine_response.metrics.last_token_time -
+                      engine_response.metrics.first_scheduled_time)
+    queue_time = engine_response.metrics.time_in_queue
+    time_per_token = _safe_div(inference_time, response.generated_token_count)
+    total_time = engine_response.metrics.last_token_time - start_time
     output_len = len(response.text)
     short_output = _truncate(response.text, 32)
     short_input = [_truncate(input_, 32) for input_ in inputs]
@@ -26,7 +30,8 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str,
     span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
                 f"input_chars=[{input_chars}] params={paramstr} "
                 f"tokenization_time={tokenization_time * 1e3:.2f}ms "
-                f"queue_and_inference_time={llm_engine_time * 1e3:.2f}ms "
+                f"queue_time={queue_time * 1e3:.2f}ms "
+                f"inference_time={inference_time * 1e3:.2f}ms "
                 f"time_per_token={time_per_token * 1e3:.2f}ms "
                 f"total_time={total_time * 1e3:.2f}ms "
                 f"input_toks={response.input_token_count}}}")
