@@ -1,5 +1,4 @@
 import argparse
-import dataclasses
 import inspect
 import time
 import uuid
@@ -37,19 +36,12 @@
 from vllm.tgis_utils import logs
 from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
                                                TypicalLogitsWarperWrapper)
+from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics,
+                                     TGISStatLogger)
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

 logger = init_logger(__name__)
-
-@dataclasses.dataclass
-class Times:
-    """Container tracking times (in seconds) when requests start and finish"""
-    # When control enters Generate or GenerateStream
-    request_start: float
-    # When the request is sent to the vLLM engine
-    engine_start: float = 0
-    # When the stream from the vLLM engine closes
-    end: float = 0
+service_metrics = ServiceMetrics()


 def with_default(value: Any, default: Any) -> Any:
@@ -63,7 +55,13 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
         context = kwargs.get("context", None) or args[-1]
         logger.exception(f"{func.__name__} caused GPU OOM error")
+        service_metrics.count_request_failure(FailureReasonLabel.OOM)
         await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+    else:
+        if "generate" in func.__name__.lower():
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
+        else:
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
     logger.exception(f"{func.__name__} failed")
     raise e

@@ -108,10 +106,20 @@ async def _post_init(self):
         self.tokenizer_group = await self.engine.get_tokenizer_group()
         self.tokenizer = await self.engine.get_tokenizer()

+        # Swap in the special TGIS stats logger
+        vllm_stat_logger = self.engine.engine.stat_logger
+        tgis_stats_logger = TGISStatLogger(
+            vllm_stat_logger=vllm_stat_logger,
+            max_sequence_len=self.config.max_model_len)
+        # 🌶️🌶️🌶️ sneaky sneak
+        self.engine.engine.stat_logger = tgis_stats_logger
+
+
     @log_rpc_handler_errors
     async def Generate(self, request: BatchedGenerationRequest,
                        context: ServicerContext) -> BatchedGenerationResponse:
         start_time = time.time()
+        service_metrics.count_generate_request(len(request.requests))
         request_id = self.request_id(context)
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, context)
@@ -120,23 +128,19 @@ async def Generate(self, request: BatchedGenerationRequest,
         request_count = len(request.requests)

         generators = []
-        timing_infos = []
         max_is_token_limit = [False] * request_count
         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i]\
                 = await self._validate_prompt_and_tokenize(
                     sampling_params, truncate_input_tokens, req.text, context)
-            timing_info = Times(request_start=start_time)
-            timing_infos.append(timing_info)
             generators.append(
-                self.timed_generator(
-                    # prompt is supplied for observability, the text is not
-                    # re-tokenized when `prompt_token_ids` is supplied
-                    self.engine.generate(prompt=req.text,
-                                         sampling_params=sampling_params,
-                                         request_id=f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids),
-                    timing_info))
+                # prompt is supplied for observability, the text is not
+                # re-tokenized when `prompt_token_ids` is supplied
+                self.engine.generate(prompt=req.text,
+                                     sampling_params=sampling_params,
+                                     request_id=f"{request_id}-{i}",
+                                     prompt_token_ids=input_ids),
+            )

         # TODO handle cancellation
         result_generator: AsyncIterator[Tuple[
@@ -151,6 +155,7 @@ async def Generate(self, request: BatchedGenerationRequest,
             # await self.engine.abort(f"{request_id}-{i}")
             # return self.create_error_response("Client disconnected")
             responses[i] = res
+            service_metrics.observe_queue_time(res)

             if deadline is not None and time.time(
             ) >= deadline and None not in responses:
@@ -173,7 +178,8 @@ async def Generate(self, request: BatchedGenerationRequest,
                 kind_log = f"Sub-request {i} from batch of {request_count}"

             self._log_unary_response(request=request, response=response,
-                                     times=timing_infos[i], kind_log=kind_log)
+                                     start_time=start_time, engine_response=res,
+                                     kind_log=kind_log)
             responses[i] = response

         return BatchedGenerationResponse(responses=responses)
@@ -182,7 +188,8 @@ async def Generate(self, request: BatchedGenerationRequest,
     async def GenerateStream(
             self, request: SingleGenerationRequest,
             context: ServicerContext) -> AsyncIterator[GenerationResponse]:
-        timing_info = Times(request_start=time.time())
+        start_time = time.time()
+        service_metrics.count_generate_request()
         request_id = self.request_id(context)
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, context)
@@ -193,16 +200,13 @@ async def GenerateStream(
             sampling_params, truncate_input_tokens, request.request.text,
             context)

-        result_generator = self.timed_generator(
-            self.engine.generate(
-                # prompt is supplied for observability, the text is not
-                # re-tokenized when `prompt_token_ids` is supplied
-                prompt=request.request.text,
-                sampling_params=sampling_params,
-                request_id=request_id,
-                prompt_token_ids=input_ids,
-            ),
-            timing_info
+        result_generator = self.engine.generate(
+            # prompt is supplied for observability, the text is not
+            # re-tokenized when `prompt_token_ids` is supplied
+            prompt=request.request.text,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=input_ids,
         )

         resp_options = request.params.response
@@ -213,9 +217,12 @@ async def GenerateStream(
         last_token_count = 0
         time_limit_reached = False
         full_output = ""
+        last_engine_response = None
         #TODO handle cancellation
         async for result in result_generator:
+            last_engine_response = result
             if first:
+                service_metrics.observe_queue_time(result)
                 first_response = self._convert_input_details(
                     result, resp_options, sampling_params,
                     GenerationResponse())
@@ -247,7 +254,8 @@ async def GenerateStream(
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
         self._log_streaming_response(request=request, response=first_response,
-                                     times=timing_info)
+                                     start_time=start_time,
+                                     engine_response=last_engine_response)

     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
@@ -314,6 +322,7 @@ async def _validate_and_convert_params(
         try:
             validate_params(params, self.max_max_new_tokens)
         except ValueError as tgis_validation_error:
+            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
             await context.abort(StatusCode.INVALID_ARGUMENT,
                                 str(tgis_validation_error))

@@ -396,6 +405,7 @@ async def _validate_and_convert_params(
         except ValueError as vllm_validation_error:
             # There may be validation cases caught by vLLM that are not covered
             # by the TGIS api validation
+            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
             await context.abort(StatusCode.INVALID_ARGUMENT,
                                 str(vllm_validation_error))

@@ -528,36 +538,32 @@ async def _validate_prompt_and_tokenize(

     @staticmethod
     def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse, times: Times,
-                            kind_log: str):
+                            response: GenerationResponse,
+                            engine_response: RequestOutput,
+                            start_time: float, kind_log: str):
         logs.log_response(inputs=[r.text for r in request.requests],
                           response=response, params=request.params,
-                          prefix_id=request.prefix_id, times=times,
-                          kind_log=kind_log, method_str="generate",
-                          logger=logger)
+                          prefix_id=request.prefix_id,
+                          engine_response=engine_response,
+                          start_time=start_time, kind_log=kind_log,
+                          method_str="generate", logger=logger)

     @staticmethod
     def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse, times: Times):
+                                response: GenerationResponse,
+                                engine_response: RequestOutput,
+                                start_time: float):
         logs.log_response(inputs=[request.request.text], response=response,
                           params=request.params, prefix_id=request.prefix_id,
-                          times=times, kind_log="Streaming response",
+                          engine_response=engine_response,
+                          start_time=start_time, kind_log="Streaming response",
                           method_str="generate_stream", logger=logger)


-    @staticmethod
-    async def timed_generator(generator: AsyncIterator[RequestOutput],
-                              times: Times) -> AsyncIterator[RequestOutput]:
-        """Injects some timing data around each result generator from the
-        LLMEngine"""
-        times.engine_start = time.time()
-        async for val in generator:
-            yield val
-        times.end = time.time()
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
+        service_metrics.observe_tokenization_request(request)
         #TODO implement these
         if request.return_offsets:
             await context.abort(StatusCode.INVALID_ARGUMENT,
@@ -578,7 +584,9 @@ async def Tokenize(self, request: BatchedTokenizeRequest,
                     tokens=None if not request.return_tokens else
                     self.tokenizer.convert_ids_to_tokens(token_ids)))

-        return BatchedTokenizeResponse(responses=responses)
+        response = BatchedTokenizeResponse(responses=responses)
+        service_metrics.observe_tokenization_response(response)
+        return response

     @log_rpc_handler_errors
     async def ModelInfo(self, request: ModelInfoRequest,
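
Note: ServiceMetrics, FailureReasonLabel, and TGISStatLogger come from the new vllm.tgis_utils.metrics module that this diff imports but does not show. For orientation only, below is a minimal sketch of the interface these call sites appear to assume; the metric names, label values, prometheus_client usage, and the stat-logger delegation are illustrative assumptions, not the module's actual implementation.

# Hypothetical sketch -- NOT the real vllm.tgis_utils.metrics module.
# It mirrors only the call sites visible in the diff above.
from enum import Enum

from prometheus_client import Counter, Histogram


class FailureReasonLabel(str, Enum):
    VALIDATION = "validation"  # request failed TGIS/vLLM parameter validation
    GENERATE = "generate"      # error raised while generating
    OOM = "oom"                # GPU out-of-memory during a request
    UNKNOWN = "unknown"


class ServiceMetrics:
    def __init__(self):
        self._request_count = Counter(
            "tgi_request_count", "Generate requests received")
        self._request_failures = Counter(
            "tgi_request_failure", "Failed requests", ["err"])
        self._queue_time = Histogram(
            "tgi_queue_duration", "Time each request spent queued (seconds)")
        self._tokenize_tokens = Histogram(
            "tgi_tokenize_request_tokens", "Tokens per tokenize request")

    def count_generate_request(self, num_requests: int = 1) -> None:
        self._request_count.inc(num_requests)

    def count_request_failure(self, reason: FailureReasonLabel) -> None:
        self._request_failures.labels(err=reason.value).inc()

    def observe_queue_time(self, engine_output) -> None:
        # engine_output is a vllm RequestOutput; assumes its metrics field
        # carries a time_in_queue value
        self._queue_time.observe(engine_output.metrics.time_in_queue or 0.0)

    def observe_tokenization_request(self, request) -> None:
        pass  # e.g. record batch size or input lengths

    def observe_tokenization_response(self, response) -> None:
        for resp in response.responses:
            self._tokenize_tokens.observe(resp.token_count)


class TGISStatLogger:
    """Sketch of the wrapper swapped in during _post_init: it keeps vLLM's own
    stat logging working while adding TGIS-style metrics."""

    def __init__(self, vllm_stat_logger, max_sequence_len: int):
        self._vllm_stat_logger = vllm_stat_logger
        self._max_sequence_len = max_sequence_len

    def __getattr__(self, name):
        # Delegate anything not overridden to the wrapped vLLM stat logger
        return getattr(self._vllm_stat_logger, name)

    def log(self, stats) -> None:
        self._vllm_stat_logger.log(stats)
        # ...additionally export TGIS metrics derived from `stats` here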