diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
index de68e9805ed..4da26f72d59 100644
--- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
+++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -190,6 +190,14 @@ class CacheTransceiverFactory
         std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
 };
 
+struct RequestStatuses
+{
+    /// Requests that have completed their transfer successfully.
+    std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
+    /// Requests that have encountered an error during their transfer.
+    std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
+};
+
 class BaseCacheTransceiver
 {
 public:
@@ -202,7 +210,10 @@
     virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
     virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
 
-    virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
+    /// Check all requests transferring context, and return the requests that have completed or encountered an error.
+    virtual RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
+        = 0;
 
     virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
 
@@ -243,7 +254,8 @@ class CacheTransceiver : public BaseCacheTransceiver
     void requestAndReceiveSync(LlmRequest* llmRequest) override;
     void requestAndReceiveAsync(LlmRequest* llmRequest) override;
 
-    void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
+    RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
 
     void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
 
diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index 7e4c26bfd78..2170370d551 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr const& mComm,
     }
 }
 
-void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
+RequestStatuses CacheTransceiver::checkContextTransferStatus(
+    std::optional<int> const& atLeastRequestNum, bool markComplete)
 {
     bool blockAll = !atLeastRequestNum.has_value();
     std::optional senderFutureTimeoutMs = std::nullopt;
@@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
         toCompleteIdSet.insert(request->mRequestId);
     }
 
+    RequestStatuses requestsStatus{};
+
     // Complete all the requests in toCompleteIdSet
     for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
     {
@@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
             if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
             {
                 future.get();
-                request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                requestsStatus.completedRequestIds.insert(request->mRequestId);
+                if (markComplete)
+                {
+                    request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                }
                 it = mSenderFutures.erase(it);
             }
             else if (status == std::future_status::timeout)
@@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
                     "Future returned unexpected status for request %ld. "
Marking as error", request->mRequestId); request->setState(LlmRequestState::kDISAGG_TRANS_ERROR); + requestsStatus.errorRequestIds.insert(request->mRequestId); it = mSenderFutures.erase(it); } } @@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe TLLM_LOG_ERROR( "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what()); request->setState(LlmRequestState::kDISAGG_TRANS_ERROR); + requestsStatus.errorRequestIds.insert(request->mRequestId); it = mSenderFutures.erase(it); } } @@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe ++it; } } + + return requestsStatus; } void CacheTransceiver::checkGenTransferStatus(std::optional const& atLeastRequestNum) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 9210fe95874..e93a908aa8f 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching() { if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(true); + mCacheTransceiver->checkContextTransferStatus(1, true); TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete"); } if (mAsyncSendWaitThread) @@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync() } if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(0); + mCacheTransceiver->checkContextTransferStatus(0, true); } ++mIterCounter; @@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests mIterCounter); if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(1); + mCacheTransceiver->checkContextTransferStatus(1, true); // will free kvCache in next iteration. 
     }
 }
 
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
index c018a1e0d1a..dd3452f0f6c 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
@@ -60,9 +60,10 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
         NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
     }
 
-    void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
+    tb::RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
     {
-        NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum);
+        NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete);
     }
 
     void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
         .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
         .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
         .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
-        .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
-            nb::call_guard<nb::gil_scoped_release>())
+        .def(
+            "check_context_transfer_status",
+            [](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
+            {
+                RequestStatuses result;
+                {
+                    nb::gil_scoped_release release;
+                    result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
+                }
+
+                auto completedRequestIds
+                    = std::vector(result.completedRequestIds.begin(), result.completedRequestIds.end());
+                auto errorRequestIds
+                    = std::vector(result.errorRequestIds.begin(), result.errorRequestIds.end());
+                return nb::make_tuple(completedRequestIds, errorRequestIds);
+            },
+            nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false)
         .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
             nb::call_guard<nb::gil_scoped_release>())
         .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
index 30bf411c9b1..7ab2ba0241d 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
@@ -56,9 +56,13 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
         PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
     }
 
-    void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
+    using RequestStatuses = tb::RequestStatuses;
+
+    RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
     {
-        PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum);
+        PYBIND11_OVERLOAD_PURE(
+            RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
     }
 
     void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
         .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
         .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync) - .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus, - py::call_guard()) + .def( + "check_context_transfer_status", + [](tb::BaseCacheTransceiver& self, std::optional const& atLeastRequestNum, bool markComplete = false) + { + RequestStatuses result; + { + py::gil_scoped_release release; + result = self.checkContextTransferStatus(atLeastRequestNum, markComplete); + } + + auto completedRequestIds + = std::vector(result.completedRequestIds.begin(), result.completedRequestIds.end()); + auto errorRequestIds + = std::vector(result.errorRequestIds.begin(), result.errorRequestIds.end()); + return py::make_tuple(completedRequestIds, errorRequestIds); + }, + py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false) .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus, py::call_guard()) .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 8f54aca6c48..3f8acba9aec 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -113,6 +113,125 @@ class BatchStatePP(BatchState): finished_ctx_reqs: list[LlmRequest] = None +class AsyncTransferManager: + """ + Handle asynchronous transfer or KV cache after a request has completed. + When running with both the KV cache transceiver and the KV cache connector, we must ensure that BOTH transfers (if any) are completed before we can release the KV cache blocks. + The AsyncTransferManager has a few key responsibilities: + 1. Track requests in transfer. + 2. Pin blocks for reuse while blocks are in transfer. + 3. Unpin blocks after all transfers are complete. + + TODO(jthomson04): This only handles async send/saving, not loading. Loading kv cache is handled through a separate codepath. Eventually, we'll want to merge these two paths. + """ + + class RequestTransferMetadata: + + def __init__(self, block_id: Optional[int]): + self.block_id = block_id + self.counter = 0 + + def start_transfer(self): + self.counter += 1 + + def end_transfer(self) -> bool: + """ + Returns: + bool: True if there are no more transfers for this request + """ + self.counter -= 1 + return self.counter == 0 + + def __init__(self, + resource_manager: "ResourceManager", + should_store_blocks: bool = True): + self.resource_manager = resource_manager + self.kv_cache_manager = resource_manager.resource_managers.get( + ResourceManagerType.KV_CACHE_MANAGER) + + self.should_store_blocks = should_store_blocks + + # Mapping of request id to the LlmRequest + self._requests_in_transfer: Dict[int, LlmRequest] = dict() + + # Mapping of request id to the the request metadata + self._request_transfer_metadata: Dict[ + int, self.RequestTransferMetadata] = dict() + + def requests_in_transfer(self) -> Dict[int, LlmRequest]: + return self._requests_in_transfer + + def start_transfer(self, request: LlmRequest): + """ + Called when a Cache transceiver or connector transfer is started. + 1. Increment the counter for the request. + 2. Releases all resources except for the KV cache, if not already released. + 3. Store KV cache blocks for reuse. 
+ """ + + req_id = request.py_request_id + + if req_id not in self._requests_in_transfer: + for resource_mgr_type in ( + ResourceManagerType.SEQ_SLOT_MANAGER, + ResourceManagerType.SPEC_RESOURCE_MANAGER): + if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[ + resource_mgr_type] is not None: + self.resource_manager.resource_managers[ + resource_mgr_type].free_resources(request) + + request.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + + if self.should_store_blocks: + block_id = self.kv_cache_manager.store_blocks_for_reuse( + request, True) + else: + block_id = None + + self._requests_in_transfer[req_id] = request + self._request_transfer_metadata[ + req_id] = self.RequestTransferMetadata(block_id) + + self._request_transfer_metadata[req_id].start_transfer() + + def end_transfer(self, request: LlmRequest) -> bool: + """ + Called after a send of KV cache is complete. + 1. Decrements counter for request. + 2. If there are no more inflight transfers for this request, unpin the blocks and mark the request as complete. + + Returns: + bool: True if the request should be terminated after call to end_transfer + """ + try: + transfer_metadata = self._request_transfer_metadata[ + request.py_request_id] + except KeyError: + logger.warning( + f"Request {request.py_request_id} not found in transfer manager" + ) + return + + if transfer_metadata.end_transfer(): + self._requests_in_transfer.pop(request.py_request_id) + self._request_transfer_metadata.pop(request.py_request_id) + + if self.should_store_blocks: + self.kv_cache_manager.unpin_blocks_by_id( + transfer_metadata.block_id) + + # We don't want to overwrite any error state. + if request.state != LlmRequestState.DISAGG_TRANS_ERROR: + request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE + + return True + + return False + + def has_any_inflight_requests(self) -> bool: + return len(self._requests_in_transfer) > 0 + + class PyExecutor: def __init__(self, @@ -231,10 +350,10 @@ def __init__(self, self.max_num_active_requests = model_engine.get_max_num_sequences() self.active_requests: List[LlmRequest] = [] self.expected_num_active_requests = 0 - self.ctx_in_transmission_requests = dict() - self.ctx_in_transmission_counter = (1 if kv_cache_transceiver else - 0) + (1 if kv_connector_manager else - 0) + self.async_transfer_manager = AsyncTransferManager( + self.resource_manager, + should_store_blocks=self.block_reuse_enabled + and not self.kv_cache_manager.is_vswa) self.previous_batch: Optional[BatchState] = None self.has_previous_draft_tokens = False self.num_scheduled_requests: int = 0 @@ -358,6 +477,10 @@ def _maybe_init_kv_connector_manager(self): module.register_forward_hook( self.kv_connector_manager.layer_post_hook) + def _end_transfer_and_maybe_terminate(self, request: LlmRequest): + if self.async_transfer_manager.end_transfer(request): + self._terminate_request(request) + def _event_loop_wrapper(self): try: with customized_gc_thresholds( @@ -926,7 +1049,7 @@ def _pp_retry_until_can_schedule(self, scheduled_batch): raise RuntimeError( "KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." ) - if not self.ctx_in_transmission_requests: + if not self.async_transfer_manager.has_any_inflight_requests(): raise RuntimeError( "No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." 
                )
@@ -945,7 +1068,6 @@ def _pp_retry_until_can_schedule(self, scheduled_batch):
                 # Let cache transceiver finish at least one cache transmission and release requests' KV cache resources
                 self._check_disagg_ctx_cache_transfer_status(1)
                 self._check_kv_transfer_timeout()
-                self._terminate_disagg_ctx_finished_requests()
             else:
                 raise RuntimeError(
                     f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries."
@@ -1143,20 +1265,8 @@ def _executor_loop_pp(self):
                         sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                     self._update_requests(previous_batch.sample_state)
 
-                    if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-                        for req in previous_batch.scheduled_ctx_reqs:
-                            if req.is_context_only_request and (
-                                    req.is_context_finished
-                                    or req.is_finished_due_to_length):
-                                block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                                    req, True)
-                                self.ctx_in_transmission_requests[
-                                    req.py_request_id] = (
-                                        (req, block_id,
-                                         self.ctx_in_transmission_counter))
-                    if self.kv_cache_transceiver:
-                        self._send_disagg_ctx_cache(
+                    self._send_kv_async(
                             previous_batch.scheduled_ctx_reqs)
 
                     self._handle_canceled_requests()
@@ -1178,9 +1288,9 @@ def _executor_loop_pp(self):
                 self.wait_on_pp_send_handles(prev_microbatch_id)
                 self.micro_batches[prev_microbatch_id] = None
 
-                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+                ):
                     self._check_kv_transfer_timeout()
-                    self._terminate_disagg_ctx_finished_requests()
 
                 if self._disagg_pp_termination_handler is not None:
                     self._disagg_pp_termination_handler.terminate_pending_requests(
@@ -1310,14 +1420,7 @@ def _kv_connector_terminate_requests(self):
         if self.kv_connector_manager:
             reqs_to_terminate = self.kv_connector_manager.get_finished()
             for req in reqs_to_terminate:
-                if req.py_request_id in self.ctx_in_transmission_requests:
-                    request, block_id, counter = self.ctx_in_transmission_requests.pop(
-                        req.py_request_id)
-                    if counter == 1:
-                        self.kv_cache_manager.unpin_blocks_by_id(block_id)
-                    else:
-                        self.ctx_in_transmission_requests[req.py_request_id] = (
-                            request, block_id, counter - 1)
+                self._end_transfer_and_maybe_terminate(req)
 
     def _kv_connector_wait_for_save(self):
         if self.kv_connector_manager is not None:
@@ -1412,24 +1515,9 @@ def _executor_loop(self):
                     self._update_request_states(scheduled_batch)
                     self._update_requests(sample_state, self.resource_manager)
 
-                    if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-                        for req in scheduled_batch.context_requests:
-                            if req.is_context_only_request and (
-                                    req.is_context_finished
-                                    or req.is_finished_due_to_length):
-                                block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                                    req, True)
-                                self.ctx_in_transmission_requests[
-                                    req.py_request_id] = (
-                                        (req, block_id,
-                                         self.ctx_in_transmission_counter))
-                    if self.kv_cache_transceiver:
-                        ctx_transmission_reqs = self._send_disagg_ctx_cache(
-                            scheduled_batch.context_requests)
-                        # For context only req in transmission, we reset the state since sampler might have changed it
-                        for req in ctx_transmission_reqs:
-                            req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
+                    self._send_kv_async(scheduled_batch.context_requests +
+                                        scheduled_batch.generation_requests)
 
                     self._handle_canceled_requests()
                     finished_requests = self._handle_responses()
@@ -1443,9 +1531,9 @@ def _executor_loop(self):
                     if self.enable_kv_cache_events:
                         self._add_kv_cache_events()
 
-                    if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+                    ):
                         self._check_kv_transfer_timeout()
-                        self._terminate_disagg_ctx_finished_requests()
 
                     self._kv_connector_terminate_requests()
 
@@ -1662,18 +1750,6 @@ def _executor_loop_overlap(self):
                 if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)
 
-                    if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-                        for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
-                            if req.is_context_only_request and (
-                                    req.is_context_finished
-                                    or req.is_finished_due_to_length):
-                                block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                                    req, True)
-                                self.ctx_in_transmission_requests[
-                                    req.py_request_id] = (
-                                        (req, block_id,
-                                         self.ctx_in_transmission_counter))
-
                 if self.drafter is not None and self.use_spec_decode:
                     # Cleanup previous draft resources used in the draft model
                     self.drafter.cleanup_previous_draft_resources()
@@ -1698,9 +1774,8 @@ def _executor_loop_overlap(self):
 
                 self._update_request_states(scheduled_batch)
 
-                ctx_transmission_reqs = self._send_disagg_ctx_cache(
-                    scheduled_batch.context_requests
-                ) if self.kv_cache_transceiver else []
+                ctx_transmission_reqs = self._send_kv_async(
+                    scheduled_batch.all_requests())
 
                 if self.previous_batch is not None:
                     self._process_previous_batch()
@@ -1716,9 +1791,9 @@ def _executor_loop_overlap(self):
                         iter_stats=iter_stats,
                         ctx_transmission_reqs=ctx_transmission_reqs)
 
-                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+                ):
                     self._check_kv_transfer_timeout()
-                    self._terminate_disagg_ctx_finished_requests()
 
                 self._kv_connector_terminate_requests()
 
@@ -2051,7 +2126,7 @@ def flag_if_kv_transfer_timed_out(req: LlmRequest, type: str) -> None:
                 )
                 req.py_kv_transfer_timed_out = True
 
-        for req, _, _ in self.ctx_in_transmission_requests.values():
+        for req in self.async_transfer_manager.requests_in_transfer().values():
             flag_if_kv_transfer_timed_out(req, "context")
 
         for req in self.active_requests:
@@ -2171,35 +2246,45 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
 
         return
 
-    @nvtx_range("_send_disagg_ctx_cache")
-    def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
-        if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
-            return []
-        for req in scheduled_ctx_requests:
-            if req.is_context_only_request and (req.is_context_finished or
-                                                req.is_finished_due_to_length):
-                self.kv_cache_transceiver.respond_and_send_async(req)
-                for resource_mgr_type in (
-                        ResourceManagerType.SEQ_SLOT_MANAGER,
-                        ResourceManagerType.SPEC_RESOURCE_MANAGER):
-                    if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
-                            resource_mgr_type] is not None:
-                        self.resource_manager.resource_managers[
-                            resource_mgr_type].free_resources(req)
-
-        self._check_disagg_ctx_cache_transfer_status(0)
-
-        # Keep track of ctx requests that are in transmission
-        ctx_transmission_reqs = [
-            req for req in scheduled_ctx_requests
-            if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
-        ]
+    @nvtx_range("_send_kv_async")
+    def _send_kv_async(self,
+                       scheduled_requests: List[LlmRequest]):
 
-        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
-            for req in ctx_transmission_reqs:
-                req.py_kv_transfer_start_time = time.time()
+        def kv_connector_request_finished(req: LlmRequest):
+            try:
+                cache_block_ids = self.kv_cache_manager.get_cache_indices(req)
+            except Exception as e:
+                logger.warning(
+                    f"Unable to get cache blocks for request {req.py_request_id}. Skipping asynchronous saving: {e}"
+                )
+            else:
+                if self.kv_connector_manager.request_finished(
+                        req, cache_block_ids):
+                    self.async_transfer_manager.start_transfer(req)
+
+        if self.kv_cache_transceiver:
+            for req in scheduled_requests:
+                if req.is_context_only_request and (
+                        req.is_context_finished
+                        or req.is_finished_due_to_length):
+                    self.kv_cache_transceiver.respond_and_send_async(req)
+
+                    self.async_transfer_manager.start_transfer(req)
+
+                    if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+                        req.py_kv_transfer_start_time = time.time()
 
-        return ctx_transmission_reqs
+        if self.kv_connector_manager:
+            if not self.disable_overlap_scheduler:
+                requests = self.previous_batch.sample_state.scheduled_requests.all_requests(
+                ) if self.previous_batch is not None else []
+            else:
+                requests = scheduled_requests
+            for req in requests:
+                if req.is_finished:
+                    kv_connector_request_finished(req)
+
+        if self.kv_cache_transceiver:
+            self._check_disagg_ctx_cache_transfer_status(0)
 
     def _get_disagg_reqs_in_error_state(self):
         return [
@@ -2217,7 +2302,41 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str):
 
     @nvtx_range("_check_disagg_ctx_cache_transfer_status")
     def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
-        self.kv_cache_transceiver.check_context_transfer_status(atLeastNum)
+        finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status(
+            atLeastNum)
+
+        completed_req_ids = set(finished_requests + error_requests)
+
+        requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
+        )
+
+        for request_id in completed_req_ids:
+
+            if request_id not in requests_in_transfer:
+                logger.warning(
+                    f"Request {request_id} not found in transfer manager")
+                continue
+
+            request = requests_in_transfer[request_id]
+
+            self._end_transfer_and_maybe_terminate(request)
+
+        # The set of requests in transfer may have changed since we terminated some requests.
+        requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
+        )
+
+        for request_id in list(requests_in_transfer.keys()):
+            request = requests_in_transfer[request_id]
+            if request.py_kv_transfer_timed_out and request_id not in completed_req_ids:
+                is_cancelled = self.kv_cache_transceiver.cancel_request(request)
+                # If cancel is successful, mark as complete so it can be cleaned up
+                # Otherwise, try at next iteration
+                if is_cancelled:
+                    request.py_kv_transfer_start_time = None
+                    request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
+
+                    self._end_transfer_and_maybe_terminate(request)
+
         self._check_cache_transfer_errors("context requests")
 
     @nvtx_range("_check_disagg_gen_cache_transfer_status")
@@ -2423,24 +2542,6 @@ def _terminate_request(self, request: LlmRequest):
             self._do_terminate_request(request)
 
     def _do_terminate_request(self, request: LlmRequest):
-        if self.kv_connector_manager is not None:
-            # Only call request_finished on the connector if the request has already been added to the kv cache manager.
-            try:
-                cache_block_ids = self.kv_cache_manager.get_cache_indices(
-                    request)
-            except IndexError:
-                # If the request has not yet been added to the kv cache manager,
-                # we still need to free resources corresponding to other resource managers.
-                self.resource_manager.free_resources(request)
-            else:
-                if self.kv_connector_manager.request_finished(
-                        request,
-                        cache_block_ids) and not self.kv_cache_transceiver:
-                    block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                        request, True)
-                    self.ctx_in_transmission_requests[request.py_request_id] = (
-                        (request, block_id, self.ctx_in_transmission_counter))
-
         self.resource_manager.free_resources(request)
 
         if self.gather_all_responses or self.dist.rank == 0:
@@ -2627,12 +2728,7 @@ def _handle_responses(self):
                     if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
                         requests_to_terminate.append(request)
                     else:
-                        if request.is_disagg_context_transmission_state:
-                            self.ctx_in_transmission_requests[
-                                request.py_request_id] = (
-                                    (request, None,
-                                     self.ctx_in_transmission_counter))
-                        else:
+                        if not request.is_disagg_context_transmission_state:
                             requests_to_terminate.append(request)
                 else:
                     new_active_requests.append(request)
@@ -2645,35 +2741,6 @@ def _handle_responses(self):
             self._terminate_request(request)
         return requests_to_terminate
 
-    @nvtx_range("_terminate_disagg_ctx_finished_requests")
-    def _terminate_disagg_ctx_finished_requests(self):
-        # make a copy of the keys, since we are modifying the dictionary in the loop
-        in_transmission_requests_id = list(
-            self.ctx_in_transmission_requests.keys())
-        for request_id in in_transmission_requests_id:
-            request, block_id, counter = self.ctx_in_transmission_requests[
-                request_id]
-
-            if request.py_kv_transfer_timed_out:
-                is_cancelled = self.kv_cache_transceiver.cancel_request(request)
-                # If cancel is successful, mark as complete so it can be cleaned up
-                # Otherwise, try at next iteration
-                if is_cancelled:
-                    request.py_kv_transfer_start_time = None
-                    request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
-
-            if request.is_disagg_context_complete_state:
-                del self.ctx_in_transmission_requests[request_id]
-                if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
-                    self._terminate_request(request)
-                elif counter == 1:
-                    self.kv_cache_manager.unpin_blocks_by_id(block_id)
-                else:
-                    self.ctx_in_transmission_requests[request_id] = ((request,
-                                                                      block_id,
-                                                                      counter
-                                                                      - 1))
-
     def _handle_logits_communication(self, previous_batch, prev_microbatch_id):
         """Handle logits communication between pipeline parallel ranks.
 
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index adb11c9416a..fefe715d409 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -21,6 +21,7 @@ l0_a10:
       - unittest/_torch/modeling/test_modeling_mistral.py
       - unittest/_torch/modeling/test_modeling_pixtral.py
       - unittest/_torch/sampler/test_trtllm_sampler.py
+      - unittest/_torch/executor/test_async_transfer_manager.py
      - unittest/_torch/executor/test_scheduler_serializable_output.py
      # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
      # test list either).
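Illustrative sketch (not part of the diff): the new `check_context_transfer_status` binding now returns a `(completed_request_ids, error_request_ids)` tuple instead of mutating request state directly, and `_check_disagg_ctx_cache_transfer_status` folds that result into the reference-counted bookkeeping. The snippet below is a minimal standalone model of that flow; `FakeTransceiver` and `FakeTransferManager` are hypothetical stand-ins, not real tensorrt_llm classes.

```python
class FakeTransceiver:
    def check_context_transfer_status(self, at_least_request_num=None, mark_complete=False):
        # Mirrors the new binding: returns (completed_request_ids, error_request_ids).
        return [1, 2], [3]


class FakeTransferManager:
    def __init__(self):
        # request_id -> number of outstanding transfers (transceiver send and/or connector save)
        self._counters = {1: 1, 2: 2, 3: 1}

    def end_transfer(self, request_id) -> bool:
        self._counters[request_id] -= 1
        if self._counters[request_id] == 0:
            del self._counters[request_id]
            return True  # last transfer drained: blocks can be unpinned, request terminated
        return False


transceiver = FakeTransceiver()
manager = FakeTransferManager()

completed, errored = transceiver.check_context_transfer_status(at_least_request_num=0)
for req_id in set(completed) | set(errored):
    if manager.end_transfer(req_id):
        print(f"request {req_id} fully drained; unpin blocks and terminate")
    else:
        print(f"request {req_id} still has an in-flight connector save")
```

In the real executor, the termination step is `_end_transfer_and_maybe_terminate`, which only calls `_terminate_request` once `AsyncTransferManager.end_transfer` reports that no transfers remain.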
diff --git a/tests/unittest/_torch/executor/test_async_transfer_manager.py b/tests/unittest/_torch/executor/test_async_transfer_manager.py
new file mode 100644
index 00000000000..cc00d9f8065
--- /dev/null
+++ b/tests/unittest/_torch/executor/test_async_transfer_manager.py
@@ -0,0 +1,182 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm.bindings import LlmRequestState
+
+
+def create_mock_request(request_id: int):
+    """Create a mock LlmRequest with the given request ID."""
+    request = MagicMock()
+    request.py_request_id = request_id
+    request.state = LlmRequestState.GENERATION_IN_PROGRESS
+    return request
+
+
+def create_mock_resource_manager(
+    kv_cache_manager=None,
+    seq_slot_manager=None,
+    spec_resource_manager=None,
+):
+    """Create a mock ResourceManager with the specified resource managers."""
+    resource_manager = MagicMock()
+    resource_manager.resource_managers = {}
+
+    if kv_cache_manager is not None:
+        resource_manager.resource_managers[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
+
+    if seq_slot_manager is not None:
+        resource_manager.resource_managers[ResourceManagerType.SEQ_SLOT_MANAGER] = seq_slot_manager
+
+    if spec_resource_manager is not None:
+        resource_manager.resource_managers[ResourceManagerType.SPEC_RESOURCE_MANAGER] = (
+            spec_resource_manager
+        )
+
+    return resource_manager
+
+
+def test_start_transfer_single_request():
+    """Test starting a single transfer."""
+    kv_cache_manager = MagicMock()
+    kv_cache_manager.store_blocks_for_reuse.return_value = 100
+    seq_slot_manager = MagicMock()
+    resource_manager = create_mock_resource_manager(
+        kv_cache_manager=kv_cache_manager, seq_slot_manager=seq_slot_manager
+    )
+    manager = AsyncTransferManager(resource_manager)
+
+    request = create_mock_request(42)
+    manager.start_transfer(request)
+
+    # Check request is tracked
+    assert 42 in manager._requests_in_transfer
+
+    transfer_metadata = manager._request_transfer_metadata[42]
+
+    assert transfer_metadata.block_id == 100
+    assert transfer_metadata.counter == 1
+
+    # Check state was updated
+    assert request.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
+
+    # Check KV cache manager was called
+    kv_cache_manager.store_blocks_for_reuse.assert_called_once_with(request, True)
+
+    # Check seq slot manager was called to free resources
+    seq_slot_manager.free_resources.assert_called_once_with(request)
+
+    manager.end_transfer(request)
+    kv_cache_manager.unpin_blocks_by_id.assert_called_once()
+
+
+def test_start_transfer_multiple_transfers_same_request():
+    """Test starting multiple transfers for the same request."""
+    kv_cache_manager = MagicMock()
+    kv_cache_manager.store_blocks_for_reuse.return_value = 100
+
+    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
+    manager = AsyncTransferManager(resource_manager)
+
+    request = create_mock_request(42)
+    manager.start_transfer(request)
+    manager.start_transfer(request)
+    manager.start_transfer(request)
+
+    # Counter should be incremented
+    transfer_metadata = manager._request_transfer_metadata[42]
+    assert transfer_metadata.counter == 3
+
+    # store_blocks_for_reuse should only be called once
+    kv_cache_manager.store_blocks_for_reuse.assert_called_once()
+
+    for _ in range(2):
+        manager.end_transfer(request)
+        kv_cache_manager.unpin_blocks_by_id.assert_not_called()
+
+    manager.end_transfer(request)
+    kv_cache_manager.unpin_blocks_by_id.assert_called_once()
+
+
+def test_transfer_without_storing_blocks():
+    """Test starting a transfer with should_store_blocks=False."""
+    kv_cache_manager = MagicMock()
+    kv_cache_manager.store_blocks_for_reuse.return_value = 0
+    spec_resource_manager = MagicMock()
+    resource_manager = create_mock_resource_manager(
+        kv_cache_manager=kv_cache_manager, spec_resource_manager=spec_resource_manager
+    )
+    manager = AsyncTransferManager(resource_manager, should_store_blocks=False)
+
+    request = create_mock_request(42)
+    manager.start_transfer(request)
+
+    # Check request is tracked
+    assert 42 in manager._requests_in_transfer
+    transfer_metadata = manager._request_transfer_metadata[42]
+    assert transfer_metadata.block_id is None  # No block stored
+    assert transfer_metadata.counter == 1
+
+    # Check KV cache manager was NOT called
+    kv_cache_manager.store_blocks_for_reuse.assert_not_called()
+    spec_resource_manager.free_resources.assert_called_once_with(request)
+
+    assert manager.end_transfer(request)
+
+    kv_cache_manager.unpin_blocks_by_id.assert_not_called()
+
+
+def test_end_transfer_preserves_error_state():
+    """Test that end_transfer does not overwrite error state."""
+    kv_cache_manager = MagicMock()
+    kv_cache_manager.store_blocks_for_reuse.return_value = 100
+    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
+    manager = AsyncTransferManager(resource_manager)
+
+    request = create_mock_request(42)
+    manager.start_transfer(request)
+
+    # Set error state before end_transfer
+    request.state = LlmRequestState.DISAGG_TRANS_ERROR
+
+    manager.end_transfer(request)
+
+    # Error state should be preserved
+    assert request.state == LlmRequestState.DISAGG_TRANS_ERROR
+
+
+def test_requests_in_transfer():
+    """Test that requests_in_transfer returns the correct mapping."""
+    kv_cache_manager = MagicMock()
+    kv_cache_manager.store_blocks_for_reuse.return_value = 100
+    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
+    manager = AsyncTransferManager(resource_manager)
+
+    request1 = create_mock_request(1)
+    request2 = create_mock_request(2)
+    request3 = create_mock_request(3)
+
+    manager.start_transfer(request1)
+    manager.start_transfer(request2)
+    manager.start_transfer(request3)
+
+    in_transfer = manager.requests_in_transfer()
+
+    assert len(in_transfer) == 3
+    assert in_transfer[1] is request1
+    assert in_transfer[2] is request2
+    assert in_transfer[3] is request3
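Illustrative usage sketch (not part of the diff): the key property of `AsyncTransferManager` is that a request whose KV cache is both sent by the transceiver and saved by the connector keeps its blocks pinned until both transfers finish. This mirrors the MagicMock setup used in the test file above; the block id `7` and request id `11` are arbitrary example values.

```python
from unittest.mock import MagicMock

from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType

# Mock out the KV cache manager; store_blocks_for_reuse returns a pinned block id.
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 7

resource_manager = MagicMock()
resource_manager.resource_managers = {
    ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager
}

manager = AsyncTransferManager(resource_manager)

request = MagicMock()
request.py_request_id = 11

manager.start_transfer(request)  # transceiver send started (counter = 1)
manager.start_transfer(request)  # connector save started (counter = 2)

# First transfer completes: blocks stay pinned because the connector save is pending.
assert not manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_not_called()

# Second transfer completes: blocks are unpinned and the request can be terminated.
assert manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_called_once_with(7)
```

The counter-per-request design is what lets `_kv_connector_terminate_requests` and `_check_disagg_ctx_cache_transfer_status` each call `_end_transfer_and_maybe_terminate` independently without double-freeing blocks.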