
Commit a51c1c8

Add enhanced event tracking with TTFT measurement and compact serialization. (#3253)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 47938af · commit a51c1c8

File tree

8 files changed: +865 −49 lines

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 3 additions & 0 deletions

@@ -195,6 +195,8 @@ def _add_request():
     request.request_id = finished_request.request_id
     request.events = finished_request.events

+    request.ttft = finished_request.ttft
+
     # Update prompt, in case engine has been suspended and resumed.
     request.prompt_tokens = finished_request.prompt_tokens.tolist()
     request.prompt_text = finished_request.prompt
@@ -409,6 +411,7 @@ def escape_str(s):
     "generated_text": req.output_text,
     "generated_tokens": req.output_tokens,
     "latency": req.time_end - req.time_start,
+    "ttft": req.ttft,  # Time-to-first-token in seconds
     "cuda_graph_request_count_map": result["cuda_graph_request_count_map"],
     "step_count": engine.step_count,
     "top_n_logprobs": getattr(req, 'generated_top_n_logprobs', None),

examples/inference/gpt/utils.py

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ def __init__(
     self.time_arrival = None
     self.time_start = None
     self.time_end = None
+    self.ttft = None  # Time-to-first-token in seconds
     self.state = "not-started"
     self.sampling_params: SamplingParams = (
         sampling_params

megatron/core/inference/config.py

Lines changed: 7 additions & 0 deletions

@@ -170,6 +170,13 @@ class InferenceConfig:
     requests when they are paused during bookkeeping.
     """

+    track_generated_token_events: bool = False
+    """
+    Whether to track per-token events with timestamps for each generated token.
+    When enabled, each generated token creates a GENERATED_TOKEN event with a
+    timestamp, useful for per-token latency analysis.
+    """
+
     metrics_writer: Optional["WandbModule"] = None
     """Wandb module for writing metrics."""
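A usage sketch for the new flag, assuming the remaining InferenceConfig fields keep their defaults (only the field added in this diff is relied on here):

    from megatron.core.inference.config import InferenceConfig

    # Sketch only: enable per-token event tracking; all other config fields
    # are assumed to have defaults.
    config = InferenceConfig(track_generated_token_events=True)
    assert config.track_generated_token_events is True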

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 22 additions & 5 deletions

@@ -30,6 +30,8 @@
 from megatron.core.inference.engines.abstract_engine import AbstractEngine
 from megatron.core.inference.headers import Headers, UnknownHeaderError
 from megatron.core.inference.inference_request import (
+    DynamicInferenceEvent,
+    DynamicInferenceEventType,
     DynamicInferenceRequest,
     DynamicInferenceRequestRecord,
     Status,
@@ -173,6 +175,7 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
     self.controller = controller
     self.context = context
     self.track_paused_request_events = inference_config.track_paused_request_events
+    self.track_generated_token_events = inference_config.track_generated_token_events
     self.enable_chunked_prefill = inference_config.enable_chunked_prefill
     self.metrics_writer = inference_config.metrics_writer
     self.logging_step_interval = inference_config.logging_step_interval
@@ -710,6 +713,7 @@ def _add_request(
         record=DynamicInferenceRequestRecord.from_request(request),
         future=self._loop.create_future(),
     )
+    request.add_event_add_engine()  # Record when request enters engine

     if request.status is None:
         request.status = Status.ACTIVE_AND_GENERATING_TOKENS
@@ -882,7 +886,21 @@ def post_process_requests(
     # Skip appending token for requests being finished due to stop words
     # (they already have their final token from the previous step)
     if request_id not in self.stop_word_being_finished_ids:
+        is_first_token = len(request.generated_tokens) == 0
         request.generated_tokens.append(token)
+        if self.track_generated_token_events:
+            event_generated_token = request.add_event_generated_token(token)
+        if is_first_token:
+            if self.track_generated_token_events:
+                first_token_event = event_generated_token
+            else:
+                first_token_event = DynamicInferenceEvent(
+                    type=DynamicInferenceEventType.GENERATED_TOKEN,
+                    payload={"token_id": token},
+                )
+            request.ttft = (
+                first_token_event.timestamp - request.event_add_engine.timestamp
+            )
     if request.tpot is None:
         request.tpot = []
     request.tpot.append(step_time)
@@ -894,6 +912,7 @@
     # Request finished by normal means (termination_id, max_length, or stop word from previous step)
     request.generated_length = len(request.generated_tokens)
     request.status = Status.COMPLETED
+    request.add_event_finish()
     finished_entry = self.requests.pop(request_id)
     finished_request = finished_entry.record[-1]
     finished_request.generated_length = len(finished_request.generated_tokens)
@@ -1102,7 +1121,7 @@ def schedule_non_chunked_prefill(self):
         self._loop.create_task, self._notify_cond_for_new_request()
     )
     req.remaining_prompt_tokens = req.remaining_prompt_tokens.new_empty(0)
-    req.add_event_add()
+    req.add_event_add_context()
     self.waiting_request_ids.popleft()
 else:
     break
@@ -1148,7 +1167,7 @@ def schedule_chunked_prefill(self):
         self._loop.create_task, self._notify_cond_for_new_request()
     )
     req.remaining_prompt_tokens = req.remaining_prompt_tokens.new_empty(0)
-    req.add_event_add()
+    req.add_event_add_context()
     # Fully scheduled, so we remove from waiting pool
     self.waiting_request_ids.popleft()
     # Only this case we keep checking the rest of the waiting queue
@@ -1274,9 +1293,7 @@ async def async_bookkeep(
     newly_paused_request_ids = newly_paused_request_ids.tolist()
     [self.get_request(i).add_event_pause() for i in newly_paused_request_ids]

-    # Mark requests finished.
-    [self.get_request(i).add_event_finish() for i in finished_request_ids.tolist()]
-    # Add finished events.
+    # Process finished requests (adds FINISH events and returns records).
     (active_request_ids, finished_request_records) = self.post_process_requests(
         active_request_ids,
         finished_request_ids,
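The TTFT bookkeeping above reads as: on the first generated token, take the timestamp of its GENERATED_TOKEN event and subtract the timestamp of the request's ADD_ENGINE event; when per-token tracking is disabled, a throwaway event is constructed purely so the timestamp is captured the same way in both configurations. A self-contained sketch of that logic using simplified stand-in classes (not the Megatron types):

    import time
    from dataclasses import dataclass, field
    from enum import Enum, auto
    from typing import Any, List, Optional

    class EventType(Enum):
        ADD_ENGINE = auto()
        GENERATED_TOKEN = auto()

    @dataclass
    class Event:
        type: EventType
        payload: Optional[Any] = None
        timestamp: float = field(default_factory=time.monotonic)

    @dataclass
    class Request:
        event_add_engine: Optional[Event] = None
        events: List[Event] = field(default_factory=list)
        generated_tokens: List[int] = field(default_factory=list)
        ttft: Optional[float] = None

    def on_token(request: Request, token: int, track_token_events: bool) -> None:
        """Mirror of the post_process_requests TTFT logic above."""
        is_first_token = len(request.generated_tokens) == 0
        request.generated_tokens.append(token)
        event = None
        if track_token_events:
            event = Event(EventType.GENERATED_TOKEN, {"token_id": token})
            request.events.append(event)
        if is_first_token:
            # With tracking off, build a throwaway event so the timestamp is
            # taken identically in both configurations.
            first = event or Event(EventType.GENERATED_TOKEN, {"token_id": token})
            request.ttft = first.timestamp - request.event_add_engine.timestamp

    req = Request(event_add_engine=Event(EventType.ADD_ENGINE))
    time.sleep(0.01)  # stand-in for queueing + prefill time
    on_token(req, token=42, track_token_events=False)
    print(f"ttft = {req.ttft:.4f}s")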

megatron/core/inference/inference_request.py

Lines changed: 65 additions & 26 deletions

@@ -159,7 +159,9 @@ def _post_deserialize(self, obj: dict):
 class DynamicInferenceEventType(Enum):
     """Dynamic inference event type."""

-    ADD = auto()
+    ADD_ENGINE = auto()  # When request is added to engine via _add_request()
+    ADD_CONTEXT = auto()  # When request is added to context (scheduled for prefill)
+    GENERATED_TOKEN = auto()  # When an output token is generated (payload = {"token_id": int})
     PAUSE = auto()
     EVICT = auto()
     FINISH = auto()
@@ -202,33 +204,46 @@ def __post_init__(self):
             DynamicInferenceEventType.ERROR_NONTRANSIENT,
         ):
             assert self.payload is not None
+        elif self.type == DynamicInferenceEventType.GENERATED_TOKEN:
+            assert (
+                self.payload is not None
+                and isinstance(self.payload, dict)
+                and "token_id" in self.payload
+            )
         else:
             assert self.payload is None

     def __str__(self):
-        payload_str = "" if self.payload is None else f", {type(self.payload).__name__}"
+        if self.type == DynamicInferenceEventType.GENERATED_TOKEN:
+            payload_str = f", token={self.payload['token_id']}"
+        elif self.payload is None:
+            payload_str = ""
+        else:
+            payload_str = f", {type(self.payload).__name__}"
         return f"[{self.timestamp:.3f}] {self.type.name}{payload_str}"

     def serialize(self) -> dict:
         """Converts the instance into a serializable dictionary.

         Returns:
-            (dict) A dictionary representation of the instance suitable for
-            serialization.
+            dict: Full event dict.
         """
-
-        # Dataclass to dict.
         torch.cuda.nvtx.range_push("DynamicInferenceEvent.serialize")
         # do not use asdict(self) - it has very high CPU overheads
         # and if there are tensors, it will try to deepcopy them
         obj = self.__dict__.copy()
         obj["type"] = self.type.name

         # Serialize payload.
-        if self.payload:
-            from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.
+        if self.payload is not None:
+            if self.type in (
+                DynamicInferenceEventType.ERROR_TRANSIENT,
+                DynamicInferenceEventType.ERROR_NONTRANSIENT,
+            ):
+                from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.
+
+                obj["payload"] = ContextErrorFactory.serialize(self.payload)

-            obj["payload"] = ContextErrorFactory.serialize(self.payload)
         torch.cuda.nvtx.range_pop()
         return obj

@@ -237,22 +252,25 @@ def deserialize(cls, obj: dict) -> "DynamicInferenceEvent":
         """Deserialize event.

         Args:
-            obj (dict): Serialized event data.
+            obj: Serialized event data dict.

         Returns:
             (DynamicInferenceEvent) Deserialized event.
         """
+        event_type = DynamicInferenceEventType[obj["type"]]

-        # Initialize event.
-        event = cls(**{**obj, "type": DynamicInferenceEventType[obj["type"]]})
+        # Pre-process payload before construction (since __post_init__ validates types).
+        init_obj = {**obj, "type": event_type}
+        if obj["payload"] is not None:
+            if event_type in (
+                DynamicInferenceEventType.ERROR_TRANSIENT,
+                DynamicInferenceEventType.ERROR_NONTRANSIENT,
+            ):
+                from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.

-        # Deserialize payload.
-        if obj["payload"]:
-            from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.
+                init_obj["payload"] = ContextErrorFactory.deserialize(obj["payload"])

-            event.payload = ContextErrorFactory.deserialize(obj["payload"])
-
-        return event
+        return cls(**init_obj)


 @experimental_api
@@ -265,7 +283,6 @@ class DynamicInferenceRequest(InferenceRequest):
     """

     request_id: int
-    generated_tokens: List[int] = field(default_factory=list)
     prompt: Optional[str] = None
     prompt_tokens: Optional[torch.Tensor] = None
     # remaining prompt tokens are used for chunked prefill
@@ -289,7 +306,10 @@ def remaining_prompt_length(self):
         """
         return len(self.remaining_prompt_tokens)

+    ttft: Optional[float] = None
     events: List[DynamicInferenceEvent] = field(default_factory=list)
+    event_add_engine: Optional[DynamicInferenceEvent] = field(default=None, repr=False)
+    generated_tokens: List[int] = field(default_factory=list)

     def __str__(self):
         return ", ".join(
@@ -302,7 +322,7 @@ def __str__(self):
             )
         )

-    def serialize(self) -> dict:
+    def serialize(self):
         """Converts the instance into a serializable dictionary.

         Returns:
@@ -312,6 +332,7 @@ def serialize(self) -> dict:
         torch.cuda.nvtx.range_push("DynamicInferenceRequest.serialize")
         obj = super().serialize()
         obj["events"] = [e.serialize() for e in self.events]
+        obj.pop("event_add_engine", None)

         # Sanity check routing_indices: Tensor [total_tokens - 1, num_layers, topk]
         if self.routing_indices is not None:
@@ -328,7 +349,7 @@

     def _post_deserialize(self, obj):
         super()._post_deserialize(obj)
-        self.events = [DynamicInferenceEvent.deserialize(e) for e in obj["events"]]
+        self.events = [DynamicInferenceEvent.deserialize(e) for e in obj.get("events", [])]

     @property
     def tracked_metadata(self) -> List[Any]:
@@ -370,13 +391,30 @@ def get_metadata_types() -> List[Tuple[str, torch.dtype, bool]]:
         ("top_n_logprobs", torch.int32, False),  # CPU for torch sampling
     ]

-    def add_event(self, type: DynamicInferenceEventType, payload: Optional[Any] = None) -> None:
+    def add_event(
+        self, type: DynamicInferenceEventType, payload: Optional[Any] = None
+    ) -> DynamicInferenceEvent:
         """Add event."""
-        self.events.append(DynamicInferenceEvent(type=type, payload=payload))
+        event = DynamicInferenceEvent(type=type, payload=payload)
+        self.events.append(event)
+        return event
+
+    def add_event_add_engine(self):
+        """Add 'add_engine' event - called when request enters the engine queue."""
+        self.event_add_engine = self.add_event(DynamicInferenceEventType.ADD_ENGINE)
+        return self.event_add_engine

-    def add_event_add(self):
-        """Add 'add' event."""
-        return self.add_event(DynamicInferenceEventType.ADD)
+    def add_event_add_context(self):
+        """Add 'add_context' event - called when request is added to context for prefill."""
+        return self.add_event(DynamicInferenceEventType.ADD_CONTEXT)
+
+    def add_event_generated_token(self, token: int):
+        """Add 'generated_token' event - records each generated token.
+
+        Args:
+            token (int): The token ID that was generated.
+        """
+        return self.add_event(DynamicInferenceEventType.GENERATED_TOKEN, {"token_id": token})

     def add_event_pause(self):
         """Add 'pause' event."""
@@ -535,6 +573,7 @@ def merge_lists(key):
     generated_log_probs=merge_lists("generated_log_probs"),
     generated_top_n_logprobs=merge_lists("generated_top_n_logprobs"),
     sampling_params=self.requests[0].sampling_params,
+    ttft=self.requests[0].ttft,
     tpot=merge_lists("tpot"),
     status=self.requests[-1].status,
     latency=self.latency,
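With track_generated_token_events enabled, the serialized event list supports per-token latency analysis directly. A standalone sketch over serialized event dicts, which carry "type" (the enum name), "timestamp", and "payload" per DynamicInferenceEvent.serialize(); the sample values below are fabricated for illustration:

    # Fabricated sample of a request's serialized events.
    events = [
        {"type": "ADD_ENGINE", "timestamp": 0.000, "payload": None},
        {"type": "ADD_CONTEXT", "timestamp": 0.002, "payload": None},
        {"type": "GENERATED_TOKEN", "timestamp": 0.050, "payload": {"token_id": 11}},
        {"type": "GENERATED_TOKEN", "timestamp": 0.081, "payload": {"token_id": 7}},
        {"type": "GENERATED_TOKEN", "timestamp": 0.113, "payload": {"token_id": 2}},
        {"type": "FINISH", "timestamp": 0.114, "payload": None},
    ]

    token_ts = [e["timestamp"] for e in events if e["type"] == "GENERATED_TOKEN"]
    add_engine_ts = next(e["timestamp"] for e in events if e["type"] == "ADD_ENGINE")
    ttft = token_ts[0] - add_engine_ts  # first token minus engine-entry time
    gaps = [b - a for a, b in zip(token_ts, token_ts[1:])]  # inter-token latencies
    print(f"ttft={ttft:.3f}s, inter-token gaps={[f'{g:.3f}' for g in gaps]}")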

megatron/training/arguments.py

Lines changed: 5 additions & 0 deletions

@@ -1593,6 +1593,11 @@ def _add_inference_args(parser):
                        help='Track paused request ids by adding \'paused\' events '
                             'to each request\'s event history. This has a very minor '
                             'impact on latency.')
+    group.add_argument('--inference-dynamic-batching-track-generated-token-events',
+                       action='store_true',
+                       help='Track per-token events with timestamps for each generated token. '
+                            'When enabled, each generated token creates a GENERATED_TOKEN event '
+                            'with a timestamp, useful for per-token latency analysis.')
     group.add_argument('--decode-only-cuda-graphs',
                        action='store_true', default=False,
                        help='Only use cuda graphs for decode-only steps, not prefill and mixed steps.')
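An illustrative invocation only; the script path matches the example touched above, but the elided model, tokenizer, and prompt arguments depend on your setup:

    # Illustrative only; replace "..." with the model, tokenizer, and prompt
    # arguments required by your setup.
    python examples/inference/gpt/gpt_dynamic_inference.py ... \
        --inference-dynamic-batching-track-generated-token-events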
