@@ -159,7 +159,9 @@ def _post_deserialize(self, obj: dict):
159159class DynamicInferenceEventType (Enum ):
160160 """Dynamic inference event type."""
161161
162- ADD = auto ()
162+ ADD_ENGINE = auto () # When request is added to engine via _add_request()
163+ ADD_CONTEXT = auto () # When request is added to context (scheduled for prefill)
164+ GENERATED_TOKEN = auto () # When an output token is generated (payload = {"token_id": int})
163165 PAUSE = auto ()
164166 EVICT = auto ()
165167 FINISH = auto ()
@@ -202,33 +204,46 @@ def __post_init__(self):
202204 DynamicInferenceEventType .ERROR_NONTRANSIENT ,
203205 ):
204206 assert self .payload is not None
207+ elif self .type == DynamicInferenceEventType .GENERATED_TOKEN :
208+ assert (
209+ self .payload is not None
210+ and isinstance (self .payload , dict )
211+ and "token_id" in self .payload
212+ )
205213 else :
206214 assert self .payload is None
207215
def __str__(self):
    """Return a one-line summary: ``[<timestamp>] <TYPE>[, <payload info>]``.

    The payload info is the token ID for GENERATED_TOKEN events, the
    payload's class name for any other event that carries a payload, and is
    omitted entirely when the payload is None.
    """
    payload = self.payload
    if self.type == DynamicInferenceEventType.GENERATED_TOKEN:
        # __post_init__ asserts that the payload is a dict holding "token_id".
        detail = f", token={payload['token_id']}"
    else:
        detail = "" if payload is None else f", {type(payload).__name__}"
    return f"[{self.timestamp:.3f}] {self.type.name}{detail}"
211224
def serialize(self) -> dict:
    """Converts the instance into a serializable dictionary.

    The event type is stored by name. Only error payloads (ERROR_TRANSIENT /
    ERROR_NONTRANSIENT) are converted, via ``ContextErrorFactory.serialize``;
    any other payload is placed in the result unchanged. Because the instance
    dict is copied shallowly, such a payload is the same object the event
    holds.

    Returns:
        dict: Full event dict.
    """
    torch.cuda.nvtx.range_push("DynamicInferenceEvent.serialize")
    try:
        # do not use asdict(self) - it has very high CPU overheads
        # and if there are tensors, it will try to deepcopy them
        obj = self.__dict__.copy()
        obj["type"] = self.type.name

        # Serialize payload.
        if self.payload is not None and self.type in (
            DynamicInferenceEventType.ERROR_TRANSIENT,
            DynamicInferenceEventType.ERROR_NONTRANSIENT,
        ):
            from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.

            obj["payload"] = ContextErrorFactory.serialize(self.payload)

        return obj
    finally:
        # Always close the range opened above, even when payload
        # serialization raises, so NVTX push/pop calls stay balanced.
        torch.cuda.nvtx.range_pop()
234249
@@ -237,22 +252,25 @@ def deserialize(cls, obj: dict) -> "DynamicInferenceEvent":
237252 """Deserialize event.
238253
239254 Args:
240- obj (dict) : Serialized event data.
255+ obj: Serialized event data dict .
241256
242257 Returns:
243258 (DynamicInferenceEvent) Deserialized event.
244259 """
260+ event_type = DynamicInferenceEventType [obj ["type" ]]
245261
246- # Initialize event.
247- event = cls (** {** obj , "type" : DynamicInferenceEventType [obj ["type" ]]})
262+ # Pre-process payload before construction (since __post_init__ validates types).
263+ init_obj = {** obj , "type" : event_type }
264+ if obj ["payload" ] is not None :
265+ if event_type in (
266+ DynamicInferenceEventType .ERROR_TRANSIENT ,
267+ DynamicInferenceEventType .ERROR_NONTRANSIENT ,
268+ ):
269+ from .contexts .dynamic_context import ContextErrorFactory # avoid circular import.
248270
249- # Deserialize payload.
250- if obj ["payload" ]:
251- from .contexts .dynamic_context import ContextErrorFactory # avoid circular import.
271+ init_obj ["payload" ] = ContextErrorFactory .deserialize (obj ["payload" ])
252272
253- event .payload = ContextErrorFactory .deserialize (obj ["payload" ])
254-
255- return event
273+ return cls (** init_obj )
256274
257275
258276@experimental_api
@@ -265,7 +283,6 @@ class DynamicInferenceRequest(InferenceRequest):
265283 """
266284
267285 request_id : int
268- generated_tokens : List [int ] = field (default_factory = list )
269286 prompt : Optional [str ] = None
270287 prompt_tokens : Optional [torch .Tensor ] = None
271288 # remaining prompt tokens are used for chunked prefill
@@ -289,7 +306,10 @@ def remaining_prompt_length(self):
289306 """
290307 return len (self .remaining_prompt_tokens )
291308
309+ ttft : Optional [float ] = None
292310 events : List [DynamicInferenceEvent ] = field (default_factory = list )
311+ event_add_engine : Optional [DynamicInferenceEvent ] = field (default = None , repr = False )
312+ generated_tokens : List [int ] = field (default_factory = list )
293313
294314 def __str__ (self ):
295315 return ", " .join (
@@ -302,7 +322,7 @@ def __str__(self):
302322 )
303323 )
304324
305- def serialize (self ) -> dict :
325+ def serialize (self ):
306326 """Converts the instance into a serializable dictionary.
307327
308328 Returns:
@@ -312,6 +332,7 @@ def serialize(self) -> dict:
312332 torch .cuda .nvtx .range_push ("DynamicInferenceRequest.serialize" )
313333 obj = super ().serialize ()
314334 obj ["events" ] = [e .serialize () for e in self .events ]
335+ obj .pop ("event_add_engine" , None )
315336
316337 # Sanity check routing_indices: Tensor [total_tokens - 1, num_layers, topk]
317338 if self .routing_indices is not None :
@@ -328,7 +349,7 @@ def serialize(self) -> dict:
328349
def _post_deserialize(self, obj):
    """Rebuild ``self.events`` from a serialized request dict.

    The parent class hook runs first. Each entry under the "events" key is
    then turned back into a ``DynamicInferenceEvent``; when the key is
    absent the request ends up with an empty event list.

    Args:
        obj (dict): Serialized request data.
    """
    super()._post_deserialize(obj)
    serialized_events = obj.get("events", [])
    self.events = list(map(DynamicInferenceEvent.deserialize, serialized_events))
332353
333354 @property
334355 def tracked_metadata (self ) -> List [Any ]:
@@ -370,13 +391,30 @@ def get_metadata_types() -> List[Tuple[str, torch.dtype, bool]]:
370391 ("top_n_logprobs" , torch .int32 , False ), # CPU for torch sampling
371392 ]
372393
def add_event(
    self, type: DynamicInferenceEventType, payload: Optional[Any] = None
) -> DynamicInferenceEvent:
    """Create an event, append it to ``self.events``, and return it.

    Args:
        type (DynamicInferenceEventType): Kind of event to record.
        payload (Optional[Any]): Data attached to the event, if any.

    Returns:
        DynamicInferenceEvent: The newly recorded event.
    """
    recorded = DynamicInferenceEvent(type=type, payload=payload)
    self.events.append(recorded)
    return recorded
401+
def add_event_add_engine(self):
    """Add 'add_engine' event - called when request enters the engine queue.

    The new event is also kept on ``self.event_add_engine`` and is returned
    to the caller.
    """
    engine_event = self.add_event(DynamicInferenceEventType.ADD_ENGINE)
    self.event_add_engine = engine_event
    return engine_event
376406
377- def add_event_add (self ):
378- """Add 'add' event."""
379- return self .add_event (DynamicInferenceEventType .ADD )
def add_event_add_context(self):
    """Add 'add_context' event - called when request is added to context for prefill.

    Returns whatever ``add_event`` returns for the recorded event.
    """
    event_type = DynamicInferenceEventType.ADD_CONTEXT
    return self.add_event(event_type)
410+
def add_event_generated_token(self, token: int):
    """Add 'generated_token' event - records each generated token.

    Args:
        token (int): The token ID that was generated.

    Returns:
        The event recorded by ``add_event``; its payload is
        ``{"token_id": token}``.
    """
    token_payload = {"token_id": token}
    return self.add_event(DynamicInferenceEventType.GENERATED_TOKEN, token_payload)
380418
381419 def add_event_pause (self ):
382420 """Add 'pause' event."""
@@ -535,6 +573,7 @@ def merge_lists(key):
535573 generated_log_probs = merge_lists ("generated_log_probs" ),
536574 generated_top_n_logprobs = merge_lists ("generated_top_n_logprobs" ),
537575 sampling_params = self .requests [0 ].sampling_params ,
576+ ttft = self .requests [0 ].ttft ,
538577 tpot = merge_lists ("tpot" ),
539578 status = self .requests [- 1 ].status ,
540579 latency = self .latency ,
0 commit comments