From 88234233d39adfa0f6363bac772e9afb32fc413c Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Thu, 18 Dec 2025 00:12:48 -0800
Subject: [PATCH 1/2] [https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 .../batch_manager/kvCacheManager.h            | 24 +++----
 .../tensorrt_llm/batch_manager/llmRequest.h   |  6 ++
 .../batch_manager/kvCacheManager.cpp          | 63 +++++++++++--------
 cpp/tensorrt_llm/executor/executorImpl.cpp    | 12 ++--
 cpp/tensorrt_llm/executor/executorImpl.h      |  4 +-
 .../nanobind/batch_manager/bindings.cpp       |  1 +
 .../nanobind/batch_manager/kvCacheManager.cpp | 19 +++++-
 .../pybind/batch_manager/bindings.cpp         |  1 +
 .../pybind/batch_manager/kvCacheManager.cpp   | 21 ++++++-
 .../batch_manager/kvCacheManagerTest.cpp      |  4 +-
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 21 +++++--
 .../_torch/pyexecutor/resource_manager.py     | 22 +++++--
 12 files changed, 134 insertions(+), 64 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index bbe23828bfd..62fc4fcb301 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -648,7 +648,7 @@ class WindowBlockManager
     void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

-    [[nodiscard]] std::optional storeBlocksForReuse(
+    [[nodiscard]] std::vector storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks = false);

     void storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest);
@@ -853,8 +853,8 @@ class WindowBlockManager
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
     //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
-    //! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
-    [[nodiscard]] std::pair> storeBlocks(
+    //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
+    [[nodiscard]] std::pair> storeBlocks(
         std::vector const& blockKeys, std::vector const& blockIds, bool pinBlocks = false);
@@ -886,8 +886,8 @@ class WindowBlockManager
     [[nodiscard]] std::shared_ptr findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

-    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    //! \brief Unpin blocks by block ids directly
+    void unpinBlocksById(std::vector const& blockIds);

     void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
     {
@@ -1103,7 +1103,7 @@ class BlockManager
     std::optional releaseBlocks(
         GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false);

-    [[nodiscard]] std::optional storeBlocksForReuse(
+    [[nodiscard]] std::vector storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false);

     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -1112,7 +1112,7 @@ class BlockManager
     /// @param sequence The generation request whose blocks should be pinned.
     void pinBlocks(GenerationRequest& sequence);

-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    void unpinBlocksById(std::vector const& blockIds);

     void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
@@ -1133,7 +1133,7 @@ class BlockManager
     void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
         executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

-    [[nodiscard]] std::pair> storeBlocks(
+    [[nodiscard]] std::pair> storeBlocks(
         std::vector const& blockKeys, std::vector const& blockIds, SizeType32 windowSize,
         bool pinBlocks = false)
     {
@@ -1584,7 +1584,7 @@ class BaseKVCacheManager
     virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

     /// \brief Store blocks for reuse for a given request id
-    [[nodiscard]] virtual std::optional storeBlocksForReuse(
+    [[nodiscard]] virtual std::vector storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks = false) = 0;
@@ -1678,7 +1678,7 @@ class BaseKVCacheManager
         BlockKey const& blockKey, SizeType32 windowSize) = 0;

-    virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
+    virtual void unpinBlocksById(std::vector const& blockIds) = 0;
 };

 class KVCacheManager : public BaseKVCacheManager
@@ -1939,7 +1939,7 @@ class KVCacheManager : public BaseKVCacheManager
     //! \brief Store newest blocks for reuse
     void storeNewBlock(LlmRequest const& llmRequest) override;

-    [[nodiscard]] std::optional storeBlocksForReuse(
+    [[nodiscard]] std::vector storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks = false) override;

     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -1960,7 +1960,7 @@ class KVCacheManager : public BaseKVCacheManager
     void pinBlocks(LlmRequest::RequestIdType requestId) override;

-    void unpinBlocksById(KVCacheBlock::IdType blockId) override;
+    void unpinBlocksById(std::vector const& blockIds) override;

     std::optional getLastBlockId(LlmRequest::RequestIdType requestId) const override;
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 276b0a6483b..5757c57362f 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1667,6 +1667,12 @@ class GenericLlmRequest
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }

+    [[nodiscard]] bool isFinishedDueToCancellation() const noexcept
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
+    }
+
     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 38fde72225b..706febed748 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
     }
 }

-std::pair> WindowBlockManager::storeBlocks(
+std::pair> WindowBlockManager::storeBlocks(
     std::vector const& blockKeys, std::vector const& blockIds, bool pinBlocks)
 {
     SizeType32 numBlocksStoredForReuse = 0;
@@ -1569,7 +1569,7 @@ std::pair> WindowBlockManager::s
     auto numBlocks = blockKeys.size();
     std::vector storedBlocks;
-    std::optional lastStoredId = std::nullopt;
+    std::vector pinnedBlockIds;
     for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
     {
         auto const bid = blockIds[blockCnt];
@@ -1620,14 +1620,14 @@ std::pair> WindowBlockManager::s
         if (pinBlocks)
         {
             searchRoot->incRefCount();
+            pinnedBlockIds.push_back(searchRoot->getBlockId());
         }
-        lastStoredId = searchRoot->getBlockId();
     }
     if (mEventManager)
     {
         mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
     }
-    return {numBlocksStoredForReuse, lastStoredId};
+    return {numBlocksStoredForReuse, pinnedBlockIds};
 }

 void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@@ -1715,15 +1715,15 @@ std::deque BlockManager::getLatestEvents(std::optional
 getEvents(timeout) : std::deque{};
 }

-std::optional BlockManager::storeBlocksForReuse(
+std::vector BlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks)
 {
-    std::optional lastStoredId = std::nullopt;
+    std::vector pinnedBlockIds;
     for (auto& [_, manager] : mWindowBlockManagers)
     {
-        lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+        pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     }
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 std::optional BlockManager::releaseBlocks(
@@ -1767,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void BlockManager::unpinBlocksById(std::vector const& blockIds)
 {
     // Use the first window size
     if (mWindowBlockManagers.empty())
     {
         return;
     }
     auto& firstManager = mWindowBlockManagers.begin()->second;
-    firstManager.unpinBlocksById(blockId);
+    firstManager.unpinBlocksById(blockIds);
 }

 void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@@ -1788,21 +1788,28 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void WindowBlockManager::unpinBlocksById(std::vector const& blockIds)
 {
-    if (blockId < 0 || static_cast(blockId) >= mAllBlocksById.size())
+    if (blockIds.empty())
     {
         return;
     }
-    auto block = mAllBlocksById[blockId];
-    while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
+
+    for (auto const& blockId : blockIds)
     {
-        block->decRefCount();
-        if (!block->hasRefs())
+        if (blockId < 0 || static_cast(blockId) >= mAllBlocksById.size())
         {
-            mEvictionPolicy->releaseBlock(block);
+            continue;
+        }
+        auto block = mAllBlocksById[blockId];
+        if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
+        {
+            block->decRefCount();
+            if (!block->hasRefs())
+            {
+                mEvictionPolicy->releaseBlock(block);
+            }
         }
-        block = std::move(block->getPrevBlock());
     }
 }
@@ -1870,7 +1877,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
     (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
 }

-std::optional WindowBlockManager::storeBlocksForReuse(
+std::vector WindowBlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks)
 {
     auto constexpr beamIdx = 0;
@@ -1883,7 +1890,10 @@ std::optional WindowBlockManager::storeBlocksForReuse(
     auto const usableSize = static_cast(uniqueTokens.size()) - 1;
     auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true);
     auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
-    return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
+
+    auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
+
+    return pinnedBlockIds;
 }

 std::optional WindowBlockManager::releaseBlocks(
@@ -1922,7 +1932,7 @@ std::optional WindowBlockManager::releaseBlocks(
         std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
             [](BlockPtr const& block) { return block->getBlockId(); });

-        auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
+        auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
         TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
             sequence.getRequestId(), numBlocksStoredForReuse);
     }
@@ -2499,15 +2509,14 @@ std::optional KVCacheManager::removeSequence(
     return lastStoredId;
 }

-std::optional KVCacheManager::storeBlocksForReuse(
+std::vector KVCacheManager::storeBlocksForReuse(
     RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks)
 {
     TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
     auto& sequence = getSequence(requestId);
-    std::optional lastStoredId
-        = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+    auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@@ -2522,9 +2531,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
     mBlockManager.pinBlocks(sequence);
 }

-void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void KVCacheManager::unpinBlocksById(std::vector const& blockIds)
 {
-    mBlockManager.unpinBlocksById(blockId);
+    mBlockManager.unpinBlocksById(blockIds);
 }

 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp
index c8cddea5d6c..2abfdb51caa 100644
--- a/cpp/tensorrt_llm/executor/executorImpl.cpp
+++ b/cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
         auto req = item.request;
         if (req->isDisaggContextCompleteState())
         {
-            // If lastBlockId was tracked, unpin it. Otherwise, just terminate.
+            // If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
             auto kvMgr = mModel->getKVCacheManager();
-            if (kvMgr && item.lastBlockId.has_value())
+            if (kvMgr && !item.pinnedBlockIds.empty())
             {
-                kvMgr->unpinBlocksById(item.lastBlockId.value());
+                kvMgr->unpinBlocksById(item.pinnedBlockIds);
             }
             else
             {
@@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
             // move the in transmission requests to another tracker
             if (llmReq->isDisaggContextTransmissionState())
             {
-                std::optional lastBlockId{};
+                std::vector pinnedBlockIds{};
                 auto kvMgr = mModel->getKVCacheManager();
                 if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
                 {
-                    lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
+                    pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
                     mModel->terminateRequest(llmReq);
                 }
-                inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
+                inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
             }
             finishedRequests.push_back(*it);
             it = activeRequests.erase(it);
diff --git a/cpp/tensorrt_llm/executor/executorImpl.h b/cpp/tensorrt_llm/executor/executorImpl.h
index 19bd00bd65b..a7f7aef075d 100644
--- a/cpp/tensorrt_llm/executor/executorImpl.h
+++ b/cpp/tensorrt_llm/executor/executorImpl.h
@@ -80,12 +80,12 @@ class Executor::Impl
     using RequestList = std::list;

     // When block reuse is enabled for context worker for disaggregated serving,
-    // we need to store the last block id so that we can unpin the block when
+    // we need to store the pinned block ids so that we can unpin them when
     // the request is finished.
     struct InTransmissionItem
     {
         LlmRequestPtr request;
-        std::optional lastBlockId;
+        std::vector pinnedBlockIds;
     };
     using InTransList = std::list;
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 17c27f43bed..b88e140469b 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
         .def_prop_ro("is_finished", &GenLlmReq::isFinished)
         .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index 7a3bcae7cf1..4c9b389e7d0 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -123,7 +123,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
         NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
     }

-    std::optional storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef llmRequest, bool pinBlocks) override
     {
         NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);
     }
@@ -363,7 +363,22 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard())
        .def("add_token", &BaseKVCacheManager::addToken,
             nb::call_guard())
        .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard())
-        .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard())
+        .def(
+            "remove_sequence",
+            [](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
+                bool pinOnRelease)
+            {
+                if (llmRequest != nullptr)
+                {
+                    return self.removeSequence(requestId, *llmRequest, pinOnRelease);
+                }
+                else
+                {
+                    return self.removeSequence(requestId, std::nullopt, pinOnRelease);
+                }
+            },
+            nb::arg("request_id"), nb::arg("llm_request") = nullptr, nb::arg("pin_on_release") = false,
+            nb::call_guard())
        .def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
             nb::call_guard())
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 1d98b0c623a..498e6973f27 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
         .def_property_readonly("is_finished", &GenLlmReq::isFinished)
         .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_property(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
index 6ab03315e1a..83bfe433815 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -111,10 +111,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
             requestId, llmRequest, pinOnRelease);
     }

-    std::optional storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef llmRequest, bool pinBlocks) override
     {
-        PYBIND11_OVERLOAD_PURE(std::optional, tbk::BaseKVCacheManager, storeBlocksForReuse,
+        PYBIND11_OVERLOAD_PURE(std::vector, tbk::BaseKVCacheManager, storeBlocksForReuse,
             requestId, llmRequest, pinBlocks);
     }
@@ -367,7 +367,22 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
             py::call_guard())
        .def("add_token", &BaseKVCacheManager::addToken, py::call_guard())
        .def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard())
-        .def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard())
+        .def(
+            "remove_sequence",
+            [](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
+                bool pinOnRelease)
+            {
+                if (llmRequest != nullptr)
+                {
+                    return self.removeSequence(requestId, *llmRequest, pinOnRelease);
+                }
+                else
+                {
+                    return self.removeSequence(requestId, std::nullopt, pinOnRelease);
+                }
+            },
+            py::arg("request_id"), py::arg("llm_request") = nullptr, py::arg("pin_on_release") = false,
+            py::call_guard())
        .def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
             py::call_guard())
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index 7bccc1bd118..0803140f1ca 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
     kvCacheManager.pinBlocks(requestId);
     auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
     ASSERT_TRUE(lastBlockIdOpt.has_value());
+    auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
+    std::vector pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
     (void) kvCacheManager.removeSequence(requestId, llmRequest);
     auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
     EXPECT_LT(freeAfterRemovePinned, totalBlocks);
-    kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
+    kvCacheManager.unpinBlocksById(pinnedBlockIds);
     auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
     EXPECT_EQ(freeAfterUnpin, totalBlocks);
 }
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 8f54aca6c48..679e56ebae3 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1147,7 +1147,8 @@ def _executor_loop_pp(self):
                     for req in previous_batch.scheduled_ctx_reqs:
                         if req.is_context_only_request and (
                                 req.is_context_finished
-                                or req.is_finished_due_to_length):
+                                or req.is_finished_due_to_length
+                        ) and not req.is_finished_due_to_cancellation:
                             block_id = self.kv_cache_manager.store_blocks_for_reuse(
                                 req, True)
                             self.ctx_in_transmission_requests[
@@ -1416,7 +1417,8 @@ def _executor_loop(self):
                     for req in scheduled_batch.context_requests:
                         if req.is_context_only_request and (
                                 req.is_context_finished
-                                or req.is_finished_due_to_length):
+                                or req.is_finished_due_to_length
+                        ) and not req.is_finished_due_to_cancellation:
                             block_id = self.kv_cache_manager.store_blocks_for_reuse(
                                 req, True)
                             self.ctx_in_transmission_requests[
@@ -1666,7 +1668,8 @@ def _executor_loop_overlap(self):
                     for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
                         if req.is_context_only_request and (
                                 req.is_context_finished
-                                or req.is_finished_due_to_length):
+                                or req.is_finished_due_to_length
+                        ) and not req.is_finished_due_to_cancellation:
                             block_id = self.kv_cache_manager.store_blocks_for_reuse(
                                 req, True)
                             self.ctx_in_transmission_requests[
@@ -2176,8 +2179,9 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
         if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
             return []
         for req in scheduled_ctx_requests:
-            if req.is_context_only_request and (req.is_context_finished or
-                                                req.is_finished_due_to_length):
+            if req.is_context_only_request and (
+                    req.is_context_finished or req.is_finished_due_to_length
+            ) and not req.is_finished_due_to_cancellation:
                 self.kv_cache_transceiver.respond_and_send_async(req)
         for resource_mgr_type in (
                 ResourceManagerType.SEQ_SLOT_MANAGER,
@@ -2441,7 +2445,12 @@ def _do_terminate_request(self, request: LlmRequest):
                 self.ctx_in_transmission_requests[request.py_request_id] = (
                     (request, block_id, self.ctx_in_transmission_counter))

-        self.resource_manager.free_resources(request)
+        store_blocks_for_reuse = not (self.block_reuse_enabled
+                                      and not self.kv_cache_manager.is_vswa
+                                      and self.kv_cache_transceiver
+                                      and request.is_context_only_request)
+        self.resource_manager.free_resources(
+            request, store_blocks_for_reuse=store_blocks_for_reuse)

         if self.gather_all_responses or self.dist.rank == 0:
             self.result_wait_queues.pop(request.py_request_id, None)
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 619f8525c17..11d93a7dd24 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -685,8 +685,13 @@ def update_kv_cache_draft_token_location(self,
             None,
         )

-    def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
-        return self.impl.remove_sequence(request.py_request_id, request,
+    def free_resources(self,
+                       request: LlmRequest,
+                       pin_on_release: bool = False,
+                       store_blocks_for_reuse: bool = True):
+        # When store_blocks_for_reuse is False, pass None to prevent block storage
+        llm_request = request if store_blocks_for_reuse else None
+        return self.impl.remove_sequence(request.py_request_id, llm_request,
                                          pin_on_release)

     def store_blocks_for_reuse(self,
@@ -1430,10 +1435,17 @@ def update_resources(self,
         else:
             resource_manager.update_resources(scheduled_batch)

-    def free_resources(self, request: LlmRequest):
-        for _, resource_manager in reversed(self.resource_managers.items()):
+    def free_resources(self,
+                       request: LlmRequest,
+                       store_blocks_for_reuse: bool = True):
+        for resource_type, resource_manager in reversed(
+                self.resource_managers.items()):
             if hasattr(resource_manager, "free_resources"):
-                resource_manager.free_resources(request)
+                if resource_type == ResourceManagerType.KV_CACHE_MANAGER:
+                    resource_manager.free_resources(
+                        request, store_blocks_for_reuse=store_blocks_for_reuse)
+                else:
+                    resource_manager.free_resources(request)

     def reorder_pipeline(self,
                          resource_manager_list: list[ResourceManagerType]):

From a90e549465aeb360f1313650be8c0fa6f6eef846 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Thu, 1 Jan 2026 19:18:51 +0000
Subject: [PATCH 2/2] Add testing and address review comments

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 .../batch_manager/kvCacheManager.cpp          |   6 +-
 .../nanobind/batch_manager/kvCacheManager.cpp |  17 +-
 .../pybind/batch_manager/kvCacheManager.cpp   |  17 +-
 tensorrt_llm/_torch/pyexecutor/py_executor.py |   7 +-
 .../_torch/pyexecutor/resource_manager.py     |  19 +-
 .../disagg_config_cancel_stress_test.yaml     |  44 ++++
 ...isagg_config_cancel_stress_test_large.yaml |  44 ++++
 .../defs/disaggregated/test_disaggregated.py  | 212 ++++++++++++++++++
 .../test_lists/test-db/l0_dgx_h100.yml        |   1 +
 9 files changed, 310 insertions(+), 57 deletions(-)
 create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test.yaml
 create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test_large.yaml

diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 706febed748..c0bf858cc9e 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1797,10 +1797,8 @@ void WindowBlockManager::unpinBlocksById(std::vector const
     for (auto const& blockId : blockIds)
     {
-        if (blockId < 0 || static_cast(blockId) >= mAllBlocksById.size())
-        {
-            continue;
-        }
+        TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast(blockId) < mAllBlocksById.size(),
+            "Block id %d is out of range", blockId);
         auto block = mAllBlocksById[blockId];
         if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
         {
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index 4c9b389e7d0..dbb2d366287 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -363,22 +363,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard())
        .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard())
        .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard())
-        .def(
-            "remove_sequence",
-            [](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
-                bool pinOnRelease)
-            {
-                if (llmRequest != nullptr)
-                {
-                    return self.removeSequence(requestId, *llmRequest, pinOnRelease);
-                }
-                else
-                {
-                    return self.removeSequence(requestId, std::nullopt, pinOnRelease);
-                }
-            },
-            nb::arg("request_id"), nb::arg("llm_request") = nullptr, nb::arg("pin_on_release") = false,
-            nb::call_guard())
+        .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard())
        .def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
             nb::call_guard())
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
index 83bfe433815..36a5bdcfa4b 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -367,22 +367,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
             py::call_guard())
        .def("add_token", &BaseKVCacheManager::addToken, py::call_guard())
        .def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard())
-        .def(
-            "remove_sequence",
-            [](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
-                bool pinOnRelease)
-            {
-                if (llmRequest != nullptr)
-                {
-                    return self.removeSequence(requestId, *llmRequest, pinOnRelease);
-                }
-                else
-                {
-                    return self.removeSequence(requestId, std::nullopt, pinOnRelease);
-                }
-            },
-            py::arg("request_id"), py::arg("llm_request") = nullptr, py::arg("pin_on_release") = false,
-            py::call_guard())
+        .def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard())
        .def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
             py::call_guard())
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 679e56ebae3..221b186d8c0 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2445,12 +2445,7 @@ def _do_terminate_request(self, request: LlmRequest):
                 self.ctx_in_transmission_requests[request.py_request_id] = (
                     (request, block_id, self.ctx_in_transmission_counter))

-        store_blocks_for_reuse = not (self.block_reuse_enabled
-                                      and not self.kv_cache_manager.is_vswa
-                                      and self.kv_cache_transceiver
-                                      and request.is_context_only_request)
-        self.resource_manager.free_resources(
-            request, store_blocks_for_reuse=store_blocks_for_reuse)
+        self.resource_manager.free_resources(request)

         if self.gather_all_responses or self.dist.rank == 0:
             self.result_wait_queues.pop(request.py_request_id, None)
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 11d93a7dd24..537146577b6 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -685,13 +685,8 @@ def update_kv_cache_draft_token_location(self,
             None,
         )

-    def free_resources(self,
-                       request: LlmRequest,
-                       pin_on_release: bool = False,
-                       store_blocks_for_reuse: bool = True):
-        # When store_blocks_for_reuse is False, pass None to prevent block storage
-        llm_request = request if store_blocks_for_reuse else None
-        return self.impl.remove_sequence(request.py_request_id, llm_request,
+    def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
+        return self.impl.remove_sequence(request.py_request_id, request,
                                          pin_on_release)

     def store_blocks_for_reuse(self,
@@ -1435,17 +1430,11 @@ def update_resources(self,
         else:
             resource_manager.update_resources(scheduled_batch)

-    def free_resources(self,
-                       request: LlmRequest,
-                       store_blocks_for_reuse: bool = True):
+    def free_resources(self, request: LlmRequest):
         for resource_type, resource_manager in reversed(
                 self.resource_managers.items()):
             if hasattr(resource_manager, "free_resources"):
-                if resource_type == ResourceManagerType.KV_CACHE_MANAGER:
-                    resource_manager.free_resources(
-                        request, store_blocks_for_reuse=store_blocks_for_reuse)
-                else:
-                    resource_manager.free_resources(request)
+                resource_manager.free_resources(request)

     def reorder_pipeline(self,
                          resource_manager_list: list[ResourceManagerType]):
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test.yaml
new file mode 100644
index 00000000000..2795ca46bd3
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test.yaml
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/bf16
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 16384
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.85
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 64
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test_large.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test_large.yaml
new file mode 100644
index 00000000000..5a538d1f714
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test_large.yaml
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-0324-FP4
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 12000
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.2
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 11
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py
index e7089ebe0fa..86ba1c45177 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated.py
@@ -200,6 +200,10 @@ def get_test_config(test_desc, example_dir, test_root):
         "gpt_oss_120b_stress":
         (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
+        "cancel_stress_test":
+        (2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
+        "cancel_stress_test_large":
+        (8, f"{test_configs_root}/disagg_config_cancel_stress_test_large.yaml"),
     }

     if test_desc not in config_map:
@@ -2098,3 +2102,211 @@ def test_disaggregated_stress_test(disaggregated_test_root,
                     threshold=test_config.accuracy_threshold,
                     env=llm_venv._new_env,
                     cwd=llm_venv.get_working_directory())
+
+
+def run_cancel_stress_test(server_url: str,
+                           num_bursts: int = 5,
+                           requests_per_burst: int = 32,
+                           prompt_len_range: tuple = (2000, 8000),
+                           cancel_after_range: tuple = (0.01, 0.1)):
+    """
+    Stress test that sends requests with large contexts and cancels them
+    during prefill to test resource cleanup under cancellation.
+
+    Args:
+        server_url: The server URL (e.g., "http://localhost:8000")
+        num_bursts: Number of request bursts to send
+        requests_per_burst: Number of concurrent requests per burst
+        prompt_len_range: (min, max) prompt length in tokens
+        cancel_after_range: (min, max) seconds to wait before cancelling
+    """
+    import asyncio
+    import random
+    import time
+
+    import aiohttp
+
+    async def spam_and_cancel(session, req_id, url, prompt_len_range,
+                              cancel_after_range):
+        """Send a request and cancel it during prefill."""
+        prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
+        prompt = "test " * (prompt_len // 5)
+
+        payload = {
+            "model": "test-model",
+            "prompt": prompt,
+            "max_tokens": 10,
+            "stream": True
+        }
+
+        try:
+            cancel_after = random.uniform(cancel_after_range[0],
+                                          cancel_after_range[1])
+            start = time.time()
+            async with session.post(
+                    f"{url}/v1/completions",
+                    json=payload,
+                    timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                async for line in resp.content:
+                    if time.time() - start > cancel_after:
+                        # Force disconnect during prefill
+                        break
+        except Exception:
+            pass  # Connection abort is expected
+
+    async def run_bursts():
+        async with aiohttp.ClientSession() as session:
+            for burst_idx in range(num_bursts):
+                tasks = [
+                    spam_and_cancel(session, i, server_url, prompt_len_range,
+                                    cancel_after_range)
+                    for i in range(requests_per_burst)
+                ]
+                await asyncio.gather(*tasks)
+                logger.info(
+                    f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
+                )
+                await asyncio.sleep(0.05)
+
+    asyncio.run(run_bursts())
+
+
+def run_disaggregated_cancel_test(example_dir,
+                                  test_desc,
+                                  env=None,
+                                  cwd=None,
+                                  num_bursts=64,
+                                  requests_per_burst=64):
+    """Run disaggregated test with request cancellation stress test."""
+    cleanup_output_files()
+    run_env = env.copy()
+    run_env["UCX_TLS"] = "^ib"
+
+    num_ranks, config_file = get_test_config(test_desc, example_dir,
+                                             os.path.dirname(__file__))
+
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 1200
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
+    server_url = f"http://{server_host}:{server_port}"
+
+    try:
+        with (open('output_workers.log', 'w') as output_workers,
+              popen(workers_cmd,
+                    stdout=output_workers,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
+              output_disagg,
+              popen(server_cmd,
+                    stdout=output_disagg,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as server_proc):
+
+            # Wait for server to be ready
+            if not wait_for_server(server_host,
+                                   server_port,
+                                   timeout_seconds=server_start_timeout):
+                raise RuntimeError(
+                    f"Disaggregated server did not become ready within {server_start_timeout} seconds"
+                )
+
+            # Run the cancel stress test
+            run_cancel_stress_test(server_url,
+                                   num_bursts=num_bursts,
+                                   requests_per_burst=requests_per_burst)
+
+            # Verify server is still healthy after stress test by sending a normal request
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
+                '-p', f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+
+    except Exception:
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
+        if 'server_proc' in locals() and 'workers_proc' in locals():
+            server_proc.terminate()
+            workers_proc.terminate()
+            server_proc.wait()
+            workers_proc.wait()
+
+
+@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
+                                                     disaggregated_example_root,
+                                                     llm_venv,
+                                                     deepseek_v3_model_root):
+    """
+    Test that the disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    src_dst_dict = {
+        deepseek_v3_model_root:
+        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=5,
+                                  requests_per_burst=32)
+
+
+@pytest.mark.skip_less_device(8)
+@skip_pre_blackwell
+@pytest.mark.parametrize("model_path", ['DeepSeek-V3-0324-FP4'])
+def test_disaggregated_cancel_large_context_requests_long(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        model_path):
+    """Test that disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    model_dir = f"{llm_models_root()}/{model_path}"
+    src_dst_dict = {
+        model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test_large",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=1000,
+                                  requests_per_burst=32)
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index a598906d4e5..85b018c1894 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -43,6 +43,7 @@ l0_dgx_h100:
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
      - unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
+      - disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
   # ------------- AutoDeploy tests ---------------
       - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2] # llmapi