1515from ..disaggregated_params import DisaggregatedParams
1616from ..llmapi .tracer import global_tracer
1717from ..llmapi .utils import AsyncQueue
18+ from ..metrics import MetricNames , MetricsCollector , RequestEventTiming
1819from ..sampling_params import LogprobParams , SamplingParams
1920from .utils import ErrorResponse , has_event_loop , is_llm_response
2021
@@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
5051
5152
5253class ResponseWrapper :
53- """Wrapper of runtime response with optional outputs computed post runtime.
54+ """
55+ 1. Wrapper of runtime response with optional outputs computed post runtime.
56+ 2. A workaround to pass around RequestPerfMetrics.
5457 """
5558
5659 def __init__ (self ,
5760 response : Union ["PostprocWorker.Output" , tllm .Response ],
58- logprobs : Optional [LogProbsResult ] = None ):
61+ logprobs : Optional [LogProbsResult ] = None ,
62+ request_perf_metrics : Optional [dict [str , float ]] = None ):
5963 self ._response = response
6064 self .logprobs = logprobs
65+ self .request_perf_metrics = request_perf_metrics
6166
6267 @property
6368 def _is_llm_response (self ):
@@ -68,6 +73,14 @@ def __getattr__(self, name):
6873 response = object .__getattribute__ (self , '_response' )
6974 return getattr (response , name )
7075
76+ def __getstate__ (self ):
77+ return (self ._response , self .logprobs , self .request_perf_metrics )
78+
79+ def __setstate__ (self , state ):
80+ self ._response = state [0 ]
81+ self .logprobs = state [1 ]
82+ self .request_perf_metrics = state [2 ]
83+
7184
7285@dataclass (slots = True )
7386class CompletionOutput :
@@ -146,6 +159,7 @@ def __init__(self,
146159 self .disaggregated_params = None
147160 self .decoding_iter = 0
148161 self ._done = False
162+ self .metrics_dict = {}
149163
150164 if has_event_loop ():
151165 self .aqueue = AsyncQueue ()
@@ -201,7 +215,9 @@ def _handle_sequence(self,
201215 finish_reasons ,
202216 response_tensors ,
203217 sequence_index ,
204- logprobs_result = None ):
218+ logprobs_result = None ,
219+ req_perf_metrics_dict : Optional [dict [str ,
220+ float ]] = None ):
205221 """ Handle a single sequence in the response. """
206222
207223 seq_idx = sequence_index
@@ -271,14 +287,17 @@ def _handle_sequence(self,
271287 else :
272288 raise ValueError (
273289 f"Unknown finish reason: { finish_reasons [src_idx ]} " )
290+ self .record_stats (output , req_perf_metrics_dict )
274291
275292 @nvtx_range_debug ("handle_response" ,
276293 color = "red" ,
277294 category = "GenerationResultBase" )
278295 def _handle_response (self ,
279296 response : Union ["PostprocWorker.Output" , tllm .Response ,
280297 ResponseWrapper , ErrorResponse ]):
298+ req_perf_metrics_dict = None
281299 if isinstance (response , ResponseWrapper ):
300+ req_perf_metrics_dict = response .request_perf_metrics
282301 logprobs_result = response .logprobs
283302 response = response ._response
284303 else :
@@ -291,6 +310,8 @@ def _handle_response(self,
291310 self ._outputs [0 ] = response .res
292311 else :
293312 self ._outputs [0 ]._postprocess_result = response .res
313+ if response .metrics :
314+ self .metrics_dict = response .metrics
294315
295316 if response .error :
296317 if self ._background_error_handler is not None and (
@@ -303,7 +324,8 @@ def _handle_response(self,
303324 handler (response .error_msg )
304325
305326 response_result = response .result
306- if hasattr (response_result , "_result" ):
327+ if hasattr (response_result , "_result" ) and isinstance (
328+ response_result ._result , bytes ):
307329 response_result .deserialize ()
308330
309331 self ._done = response_result .is_final
@@ -322,11 +344,12 @@ def _handle_response(self,
322344 if self .sampling_params .use_beam_search :
323345 for beam_idx , _ in enumerate (response_result .output_token_ids ):
324346 self ._handle_sequence (finish_reasons , response_result ,
325- beam_idx , logprobs_result )
347+ beam_idx , logprobs_result ,
348+ req_perf_metrics_dict )
326349 else :
327350 self ._handle_sequence (finish_reasons , response_result ,
328351 response_result .sequence_index ,
329- logprobs_result )
352+ logprobs_result , req_perf_metrics_dict )
330353
331354 if response_result .context_logits is not None :
332355 self ._context_logits = response_result .context_logits
@@ -342,6 +365,29 @@ def _handle_response(self,
342365 else :
343366 raise ValueError (f"Unknown response type: { response } " )
344367
368+ def record_stats (self ,
369+ output : CompletionOutput ,
370+ stats : Optional [dict [str , float ]] = None ) -> None :
371+ """Record the stats of the generation result.
372+
373+ Args:
374+ output (CompletionOutput): The output of the generation result.
375+ stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
376+ """
377+ if not stats :
378+ return
379+ metrics_stats = {}
380+ if output .finish_reason :
381+ metrics_stats .update ({
382+ MetricsCollector .labelname_finish_reason :
383+ output .finish_reason
384+ })
385+ processed_metrics_stat = _process_req_perf_metrics (
386+ stats , len (output .token_ids ), self .sampling_params .n > 1 )
387+ if processed_metrics_stat :
388+ metrics_stats .update (processed_metrics_stat )
389+ self .metrics_dict = metrics_stats
390+
345391
346392class DetokenizedGenerationResultBase (GenerationResultBase ):
347393 ''' The base class for the generation result with detokenization support. '''
@@ -688,3 +734,30 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
688734
689735 return LogProbsResult (prompt = prompt_logprobs ,
690736 generation = generation_logprobs )
737+
738+
def _process_req_perf_metrics(
        req_perf_metrics_dict: Optional[dict[str, float]],
        output_length: int,
        is_multiple_response: bool = False) -> dict[MetricNames, float]:
    """Derive latency metrics from raw request timing events.

    Args:
        req_perf_metrics_dict: Raw event timestamps keyed by
            ``RequestEventTiming``; may be None or empty.
        output_length: Number of generated tokens in the sequence.
        is_multiple_response: True when the request produced multiple
            responses (n > 1), in which case TPOT is not computed.

    Returns:
        A dict with TTFT, E2E, REQUEST_QUEUE_TIME and (when applicable)
        TPOT; entries with non-positive values are dropped, so missing
        events (which default to 0) never surface as bogus metrics.
    """
    if not req_perf_metrics_dict:
        return {}
    arrival = req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
    first_token = req_perf_metrics_dict.get(
        RequestEventTiming.FIRST_TOKEN_TIME, 0)
    last_token = req_perf_metrics_dict.get(
        RequestEventTiming.LAST_TOKEN_TIME, 0)
    first_scheduled = req_perf_metrics_dict.get(
        RequestEventTiming.FIRST_SCHEDULED_TIME, 0)
    stat = {
        MetricNames.TTFT: first_token - arrival,
        MetricNames.E2E: last_token - arrival,
        MetricNames.REQUEST_QUEUE_TIME: first_scheduled - arrival,
    }
    if output_length > 1 and not is_multiple_response:
        # Time-per-output-token over the decode phase only.
        stat[MetricNames.TPOT] = \
            (last_token - first_token) / (output_length - 1)
    return {name: value for name, value in stat.items() if value > 0}
0 commit comments