diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 276b0a6483b..1847c7fcb2e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1667,6 +1667,16 @@ class GenericLlmRequest
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }
 
+    [[nodiscard]] bool isFinishedNormal() const noexcept
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason)
+            {
+                return reason == executor::FinishReason::kEND_ID || reason == executor::FinishReason::kSTOP_WORDS
+                    || reason == executor::FinishReason::kLENGTH;
+            });
+    }
+
     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 17c27f43bed..9187c432428 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
         .def_prop_ro("is_finished", &GenLlmReq::isFinished)
         .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_prop_ro("is_finished_normal", &GenLlmReq::isFinishedNormal)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 1d98b0c623a..c6714d3586d 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
         .def_property_readonly("is_finished", &GenLlmReq::isFinished)
         .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_property_readonly("is_finished_normal", &GenLlmReq::isFinishedNormal)
         .def_property(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
index 59f7256c98a..b4cc9d9734a 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
@@ -234,6 +234,13 @@ def update_state_after_alloc(self, request: LlmRequest,
             block_ids: The KV cacheblock IDs that were allocated.
         """
 
+    def wait_for_initialization(self):
+        """
+        Some connectors need to wait for some resources to be initialized.
+        For example, FlexKV needs to wait for the FlexKV manager to be initialized.
+        """
+        return
+
 
 # An internal dataclass to handle async saving/loading requests.
 @dataclass
@@ -570,3 +577,7 @@ def layer_pre_hook(self, module, *args):
 
     def layer_post_hook(self, module, *args):
         self.worker.save_kv_layer(module.layer_idx, torch.cuda.current_stream())
+
+    def wait_for_initialization(self):
+        if self.scheduler is not None:
+            self.scheduler.wait_for_initialization()
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 8f54aca6c48..bc80d61750d 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -358,6 +358,8 @@ def _maybe_init_kv_connector_manager(self):
                 module.register_forward_hook(
                     self.kv_connector_manager.layer_post_hook)
 
+        self.kv_connector_manager.wait_for_initialization()
+
     def _event_loop_wrapper(self):
         try:
             with customized_gc_thresholds(
@@ -610,7 +612,7 @@ def profile_step():
                 if prev_device_step_time is None:
                     prev_device_step_time = "N/A"  # Handle first iteration
                 else:
-                    prev_device_step_time = f"{prev_device_step_time}ms"
+                    prev_device_step_time = f"{prev_device_step_time} ms"
                 host_step_time = (end_time - start_time) * 1000  # milliseconds
                 formatted_timestamp = datetime.datetime.now().strftime(
                     "%Y-%m-%d %H:%M:%S")
@@ -620,7 +622,7 @@ def profile_step():
                     f"rank = {self.dist.rank}, "
                     f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
                     f"{self.executor_request_queue.num_fetch_requests}, "
-                    f"host_step_time = {host_step_time}ms, "
+                    f"host_step_time = {host_step_time:.3f} ms, "
                    f"prev_device_step_time = {prev_device_step_time}, "
                     f"timestamp = {formatted_timestamp}, "
                     f"num_scheduled_requests: {self.num_scheduled_requests}, "
@@ -1302,7 +1304,6 @@ def _kv_connector_start_batch(self, scheduled_batch):
         if self.kv_connector_manager:
             self.kv_connector_manager.take_scheduled_requests_pending_load(
                 scheduled_batch)
-            self.kv_connector_manager.handle_metadata()
             self.kv_connector_manager.worker.start_load_kv(
                 torch.cuda.current_stream())
 
@@ -1348,6 +1349,10 @@ def _executor_loop(self):
                 finished_requests = []
 
                 can_queue = self._can_queue(scheduled_batch)
+
+                if self.kv_connector_manager:
+                    self.kv_connector_manager.handle_metadata()
+
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1578,6 +1583,9 @@ def _executor_loop_overlap(self):
 
                 self._pause_requests(scheduled_batch.paused_requests)
 
+                if self.kv_connector_manager:
+                    self.kv_connector_manager.handle_metadata()
+
                 can_queue = self._can_queue(scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:
@@ -2634,6 +2642,9 @@ def _handle_responses(self):
                             self.ctx_in_transmission_counter))
                 else:
                     requests_to_terminate.append(request)
+
+                    if self.kv_connector_manager is not None:
+                        self.resource_manager.free_slot_only(request)
             else:
                 new_active_requests.append(request)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index bd1857dda27..4eeb3b22bf4 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -177,6 +177,7 @@ def _get_mapping(_mapping: Mapping) -> Mapping:
             tp_size=tensorrt_llm.mpi_world_size(),
             gpus_per_node=tensorrt_llm.default_gpus_per_node(),
             rank=tensorrt_llm.mpi_rank())
+        executor_config.mapping = mapping
     else:
         mapping = copy.deepcopy(_mapping)
         mapping.rank = tensorrt_llm.mpi_rank()
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 619f8525c17..622dda558e5 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1441,6 +1441,16 @@ def reorder_pipeline(self,
         for resource_manager in resource_manager_list:
             self.resource_managers.move_to_end(resource_manager)
 
+    def free_slot_only(self, request: LlmRequest):
+        """Only free the slot for the request, without freeing other resources.
+        This is used to release the slot early when decode finishes, before
+        the put task completes.
+        """
+        seq_slot_manager = self.get_resource_manager(
+            ResourceManagerType.SEQ_SLOT_MANAGER)
+        if seq_slot_manager is not None:
+            seq_slot_manager.free_resources(request)
+
 
 class PeftCacheManager(BaseResourceManager):
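
Usage note (not part of the patch): the new wait_for_initialization() hook is a no-op by default and is called once by PyExecutor._maybe_init_kv_connector_manager after the layer hooks are registered, so a connector whose backing store comes up asynchronously can block until it is usable. The sketch below is illustrative only; the class name MyConnectorScheduler, its (omitted) base class, and the background bring-up thread are assumptions, while the wait_for_initialization() entry point itself comes from the diff above.

# Illustrative sketch: a connector scheduler whose external KV store
# initializes asynchronously and that blocks in wait_for_initialization()
# until the store is ready. Names below are hypothetical.
import threading
import time


class MyConnectorScheduler:  # assumed to subclass the real connector scheduler base class
    def __init__(self):
        self._ready = threading.Event()
        # Bring the (hypothetical) external KV store up in the background.
        threading.Thread(target=self._bring_up_store, daemon=True).start()

    def _bring_up_store(self):
        time.sleep(1.0)  # stand-in for connecting to the store / allocating pools
        self._ready.set()

    def wait_for_initialization(self):
        # Invoked once by the executor after the KV connector manager is wired up
        # (see _maybe_init_kv_connector_manager in py_executor.py above).
        self._ready.wait()

On the request side, the new is_finished_normal property mirrors isFinishedNormal(): it is true only when every beam finished with kEND_ID, kSTOP_WORDS, or kLENGTH, letting Python callers distinguish normal completion from other finish reasons.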