diff --git a/tensorrt_llm/_torch/pyexecutor/handle_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_logits.py
index 17c390735c9..233a207c2a4 100644
--- a/tensorrt_llm/_torch/pyexecutor/handle_logits.py
+++ b/tensorrt_llm/_torch/pyexecutor/handle_logits.py
@@ -83,8 +83,3 @@ def __call__(
                 logits_view = logits[logits_begin:logits_end].reshape(
                     1, beam_width, -1)
                 llm_req.py_result.append_generation_logits(logits_view)
-
-        # Finalize any remaining logits transfers for all requests in chunked mode
-        for llm_req in chain(context_requests, generation_requests):
-            if llm_req.py_use_chunked_generation_logits and llm_req.py_return_generation_logits:
-                llm_req.py_result.transfer_remaining_device_logits()
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 41299733639..98f73d17e5d 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2637,6 +2637,12 @@ def _handle_responses(self):
             if request.return_perf_metrics and request.py_decoding_iter >= 1:
                 request.update_perf_metrics(self.iter_counter)
 
+            if request.is_finished:
+                # Finalize any remaining logits transfers for the finished request in chunked mode
+                if request.py_use_chunked_generation_logits and request.py_return_generation_logits:
+                    with torch.inference_mode():
+                        request.py_result.transfer_remaining_device_logits()
+
             request_done = False
             if request.py_decoding_iter == 1 or request.is_finished or \
                     request.py_decoding_iter % self.stream_interval == 0:
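
For context, here is a minimal sketch of the chunked generation-logits pattern this change relies on: per-step logits stay buffered on device and the final partial chunk is copied to host only once, when the request finishes, which is why the finalization call now sits behind `request.is_finished` in `_handle_responses` instead of scanning every request on every iteration in `HandleLogits`. The class and attribute names below (`ChunkedLogitsResult`, `_device_chunks`, `_host_chunks`, `chunk_size`) are hypothetical stand-ins for illustration, not the actual TensorRT-LLM `py_result` implementation.

```python
# Illustrative sketch only: a stand-in for the chunked generation-logits flow,
# not the TensorRT-LLM API. Names here are hypothetical.
import torch


class ChunkedLogitsResult:
    """Buffers per-step generation logits on device and moves them to host in chunks."""

    def __init__(self, chunk_size: int = 8):
        self.chunk_size = chunk_size
        self._device_chunks: list[torch.Tensor] = []  # pending device-side logits
        self._host_chunks: list[torch.Tensor] = []    # logits already copied to host

    def append_generation_logits(self, logits_view: torch.Tensor) -> None:
        # Keep the new step's logits on device; flush once a full chunk accumulates.
        self._device_chunks.append(logits_view)
        if len(self._device_chunks) >= self.chunk_size:
            self._flush()

    def transfer_remaining_device_logits(self) -> None:
        # Called once when the request finishes: move any partial chunk to host.
        if self._device_chunks:
            self._flush()

    def _flush(self) -> None:
        stacked = torch.cat(self._device_chunks, dim=0)
        self._host_chunks.append(stacked.to("cpu", non_blocking=True))
        self._device_chunks.clear()


# Usage mirroring the new placement in _handle_responses: finalize only when the
# request is done, wrapped in inference_mode as in the added code.
result = ChunkedLogitsResult(chunk_size=4)
for _ in range(6):  # pretend 6 decoding steps each produced logits of shape (1, beam_width, vocab)
    result.append_generation_logits(torch.randn(1, 1, 32))
with torch.inference_mode():
    result.transfer_remaining_device_logits()
```

Under this reading, the diff does not change what is transferred, only when: the per-iteration sweep over `chain(context_requests, generation_requests)` is replaced by a single finalization per finished request in the response path.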