Skip to content

Commit 136aab5

Browse files
authored
fix: Update num_of_ctx_tokens in iteration stats (NVIDIA#3785)
* Update num_of_ctx_tokens in iteration stats * Revert unnecessary change of imported module
1 parent a4b483b commit 136aab5

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tensorrt_llm._utils import global_mpi_rank, nvtx_range
2222
from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
2323
IterationStats, KvCacheStats,
24-
RequestType)
24+
RequestType, StaticBatchingStats)
2525
from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
2626
from tensorrt_llm.logger import logger
2727

@@ -494,11 +494,15 @@ def profile_step():
494494
def _get_init_iter_stats(self, num_new_active_requests,
495495
new_active_requests_queue_latency_ms):
496496
stats = IterationStats()
497-
stats.timestamp = ""
497+
stats.timestamp = datetime.datetime.now().strftime(
498+
"%m-%d-%Y %H:%M:%S.%f")
498499

499500
stats.num_new_active_requests = num_new_active_requests
500501
stats.num_active_requests = len(self.active_requests)
501502
stats.new_active_requests_queue_latency_ms = new_active_requests_queue_latency_ms
503+
stats.inflight_batching_stats = InflightBatchingStats()
504+
# staticBatchingStats is not used in pytorch path
505+
stats.static_batching_stats = StaticBatchingStats()
502506
return stats
503507

504508
def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
@@ -532,17 +536,17 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
532536
kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate
533537
stats.kv_cache_stats = kv_stats_to_save
534538

535-
model_stats = InflightBatchingStats()
536-
model_stats.num_scheduled_requests = len(
539+
stats.inflight_batching_stats.num_scheduled_requests = len(
537540
scheduled_batch.context_requests) + len(
538541
scheduled_batch.generation_requests)
539-
model_stats.num_context_requests = len(scheduled_batch.context_requests)
540-
model_stats.num_gen_requests = len(scheduled_batch.generation_requests)
541-
model_stats.num_paused_requests = len(scheduled_batch.paused_requests)
542-
model_stats.avg_num_decoded_tokens_per_iter = 0
543-
model_stats.num_ctx_tokens = 0
544-
model_stats.micro_batch_id = 0
545-
stats.inflight_batching_stats = model_stats
542+
stats.inflight_batching_stats.num_context_requests = len(
543+
scheduled_batch.context_requests)
544+
stats.inflight_batching_stats.num_gen_requests = len(
545+
scheduled_batch.generation_requests)
546+
stats.inflight_batching_stats.num_paused_requests = len(
547+
scheduled_batch.paused_requests)
548+
stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
549+
stats.inflight_batching_stats.micro_batch_id = 0
546550
return stats
547551

548552
def _append_iter_stats(self, stats):
@@ -624,6 +628,10 @@ def _executor_loop_pp(self):
624628
decoder_state = self._forward_step_inter_pp(
625629
scheduled_batch)
626630

631+
if self.enable_iter_perf_stats:
632+
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
633+
'num_ctx_tokens']
634+
627635
batch_state = BatchStatePP(
628636
decoder_state=decoder_state,
629637
iter_start_time=iter_start_time,
@@ -717,6 +725,9 @@ def _executor_loop_pp_overlap(self):
717725
scheduled_batch, batch_outputs)
718726
self._update_request_states(scheduled_batch)
719727

728+
if self.enable_iter_perf_stats:
729+
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
730+
'num_ctx_tokens']
720731
batch_state = BatchStatePP(
721732
decoder_state=decoder_state,
722733
iter_start_time=iter_start_time,
@@ -887,6 +898,8 @@ def _executor_loop(self):
887898
self._gather_dp_requests_num()
888899

889900
if self.enable_iter_perf_stats:
901+
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
902+
'num_ctx_tokens']
890903
self._process_iter_stats(
891904
finished_requests,
892905
BatchState(decoder_state=DecoderState(
@@ -1040,6 +1053,10 @@ def _executor_loop_overlap(self):
10401053
if r.get_context_remaining_length() == 0
10411054
]
10421055

1056+
if self.enable_iter_perf_stats:
1057+
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
1058+
'num_ctx_tokens']
1059+
10431060
self.previous_batch = BatchState(
10441061
decoder_state=decoder_state,
10451062
iter_start_time=iter_start_time,

0 commit comments

Comments
 (0)