@@ -21,7 +21,7 @@
 from tensorrt_llm._utils import global_mpi_rank, nvtx_range
 from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            RequestType)
+                                            RequestType, StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
 from tensorrt_llm.logger import logger
 
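StaticBatchingStats joins the executor-binding imports so that the init path
below can attach a default-constructed instance next to InflightBatchingStats.
A one-line smoke test that the binding exposes the type (assumes an installed
tensorrt_llm build with the C++ bindings):

    from tensorrt_llm.bindings.executor import StaticBatchingStats
    print(StaticBatchingStats())  # default-constructed, as in _get_init_iter_stats
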
@@ -494,11 +494,15 @@ def profile_step():
     def _get_init_iter_stats(self, num_new_active_requests,
                              new_active_requests_queue_latency_ms):
         stats = IterationStats()
-        stats.timestamp = ""
+        stats.timestamp = datetime.datetime.now().strftime(
+            "%m-%d-%Y %H:%M:%S.%f")
 
         stats.num_new_active_requests = num_new_active_requests
         stats.num_active_requests = len(self.active_requests)
         stats.new_active_requests_queue_latency_ms = new_active_requests_queue_latency_ms
+        stats.inflight_batching_stats = InflightBatchingStats()
+        # static_batching_stats is not used in the PyTorch path
+        stats.static_batching_stats = StaticBatchingStats()
         return stats
 
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
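The timestamp switches from an empty placeholder to a wall-clock string, so
consecutive IterationStats records can be ordered and correlated with logs.
A minimal standalone sketch of the chosen strftime format (illustrative, not
part of the patch; the surrounding module is assumed to already import
datetime, since the import hunk above does not add it):

    import datetime

    # "%m-%d-%Y %H:%M:%S.%f" renders month-day-year plus microseconds,
    # e.g. "06-15-2024 13:45:12.123456"
    print(datetime.datetime.now().strftime("%m-%d-%Y %H:%M:%S.%f"))
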
@@ -532,17 +536,17 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
             kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate
             stats.kv_cache_stats = kv_stats_to_save
 
-        model_stats = InflightBatchingStats()
-        model_stats.num_scheduled_requests = len(
+        stats.inflight_batching_stats.num_scheduled_requests = len(
             scheduled_batch.context_requests) + len(
                 scheduled_batch.generation_requests)
-        model_stats.num_context_requests = len(scheduled_batch.context_requests)
-        model_stats.num_gen_requests = len(scheduled_batch.generation_requests)
-        model_stats.num_paused_requests = len(scheduled_batch.paused_requests)
-        model_stats.avg_num_decoded_tokens_per_iter = 0
-        model_stats.num_ctx_tokens = 0
-        model_stats.micro_batch_id = 0
-        stats.inflight_batching_stats = model_stats
+        stats.inflight_batching_stats.num_context_requests = len(
+            scheduled_batch.context_requests)
+        stats.inflight_batching_stats.num_gen_requests = len(
+            scheduled_batch.generation_requests)
+        stats.inflight_batching_stats.num_paused_requests = len(
+            scheduled_batch.paused_requests)
+        stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
+        stats.inflight_batching_stats.micro_batch_id = 0
         return stats
 
     def _append_iter_stats(self, stats):
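This hunk stops building a throwaway InflightBatchingStats and instead fills
in the instance that _get_init_iter_stats now attaches up front. The old code
also forced num_ctx_tokens to 0 here, which would clobber the value the
executor loops below write; sharing one instance lets that field survive.
A simplified sketch of the pattern (hypothetical stand-in classes, not the
real bindings):

    # Stand-ins for IterationStats / InflightBatchingStats from the bindings.
    class Inflight:
        def __init__(self):
            self.num_ctx_tokens = 0
            self.num_scheduled_requests = 0

    class Stats:
        def __init__(self):
            self.inflight_batching_stats = Inflight()  # created once, up front

    stats = Stats()
    stats.inflight_batching_stats.num_ctx_tokens = 128        # set by an executor loop
    stats.inflight_batching_stats.num_scheduled_requests = 4  # set in the update step
    # The loop's value survives because no fresh object replaces the attribute:
    assert stats.inflight_batching_stats.num_ctx_tokens == 128
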
@@ -624,6 +628,10 @@ def _executor_loop_pp(self):
                 decoder_state = self._forward_step_inter_pp(
                     scheduled_batch)
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
+
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -717,6 +725,9 @@ def _executor_loop_pp_overlap(self):
                     scheduled_batch, batch_outputs)
                 self._update_request_states(scheduled_batch)
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -887,6 +898,8 @@ def _executor_loop(self):
                 self._gather_dp_requests_num()
 
             if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
                 self._process_iter_stats(
                     finished_requests,
                     BatchState(decoder_state=DecoderState(
@@ -1040,6 +1053,10 @@ def _executor_loop_overlap(self):
                     if r.get_context_remaining_length() == 0
                 ]
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
+
                 self.previous_batch = BatchState(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
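All four executor-loop variants (_executor_loop_pp, _executor_loop_pp_overlap,
_executor_loop, _executor_loop_overlap) now copy num_ctx_tokens out of
model_engine.iter_states under the same enable_iter_perf_stats guard. A hedged
end-to-end sketch of reading the populated stats back through the high-level
API (the constructor flag and the get_stats signature are assumptions; verify
against your tensorrt_llm version):

    from tensorrt_llm import LLM

    # enable_iter_perf_stats is assumed to map onto the executor guard above.
    llm = LLM(model="<hf-model-or-engine-dir>", enable_iter_perf_stats=True)
    llm.generate(["Hello, world"])

    # get_stats is assumed to yield one record per engine iteration, including
    # the inflight-batching fields this patch populates (e.g. num_ctx_tokens).
    for iteration_record in llm.get_stats(timeout=2):
        print(iteration_record)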