Skip to content

Commit 655832d

Browse files
committed
[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 9ba1426 commit 655832d

File tree

8 files changed

+73
-27
lines changed

8 files changed

+73
-27
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ class WindowBlockManager
631631

632632
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
633633

634-
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
634+
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
635635
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
636636

637637
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@@ -836,8 +836,8 @@ class WindowBlockManager
836836
//! \param blockKeys Key of each block.
837837
//! \param blockIds Id of each block.
838838
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
839-
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
840-
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
839+
//! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
840+
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
841841
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
842842
bool pinBlocks = false);
843843

@@ -869,8 +869,8 @@ class WindowBlockManager
869869

870870
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
871871

872-
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
873-
void unpinBlocksById(KVCacheBlock::IdType blockId);
872+
//! \brief Unpin blocks by block ids directly
873+
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
874874

875875
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
876876
{
@@ -1086,7 +1086,7 @@ class BlockManager
10861086
std::optional<KVCacheBlock::IdType> releaseBlocks(
10871087
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
10881088

1089-
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
1089+
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
10901090
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
10911091

10921092
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -1095,7 +1095,7 @@ class BlockManager
10951095
/// @param sequence The generation request whose blocks should be pinned.
10961096
void pinBlocks(GenerationRequest& sequence);
10971097

1098-
void unpinBlocksById(KVCacheBlock::IdType blockId);
1098+
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
10991099

11001100
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
11011101

@@ -1116,7 +1116,7 @@ class BlockManager
11161116
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
11171117
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
11181118

1119-
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
1119+
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
11201120
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
11211121
SizeType32 windowSize, bool pinBlocks = false)
11221122
{
@@ -1567,7 +1567,7 @@ class BaseKVCacheManager
15671567
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
15681568

15691569
/// \brief Store blocks for reuse for a given request id
1570-
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
1570+
[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
15711571
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
15721572
= 0;
15731573

@@ -1661,7 +1661,7 @@ class BaseKVCacheManager
16611661
BlockKey const& blockKey, SizeType32 windowSize)
16621662
= 0;
16631663

1664-
virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
1664+
virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
16651665
};
16661666

16671667
class KVCacheManager : public BaseKVCacheManager
@@ -1922,7 +1922,7 @@ class KVCacheManager : public BaseKVCacheManager
19221922
//! \brief Store newest blocks for reuse
19231923
void storeNewBlock(LlmRequest const& llmRequest) override;
19241924

1925-
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
1925+
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
19261926
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
19271927

19281928
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -1943,7 +1943,7 @@ class KVCacheManager : public BaseKVCacheManager
19431943

19441944
void pinBlocks(LlmRequest::RequestIdType requestId) override;
19451945

1946-
void unpinBlocksById(KVCacheBlock::IdType blockId) override;
1946+
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
19471947

19481948
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
19491949

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,12 @@ class GenericLlmRequest
16671667
[](auto reason) { return reason == executor::FinishReason::kLENGTH; });
16681668
}
16691669

1670+
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
1671+
{
1672+
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
1673+
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
1674+
}
1675+
16701676
[[nodiscard]] bool isTimedOut() const
16711677
{
16721678
if (!mAllottedTimeMs.has_value())

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
161161
.def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
162162
.def_prop_ro("is_finished", &GenLlmReq::isFinished)
163163
.def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
164+
.def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
164165
.def_prop_rw(
165166
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
166167
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
165165
.def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
166166
.def_property_readonly("is_finished", &GenLlmReq::isFinished)
167167
.def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
168+
.def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
168169
.def_property(
169170
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
170171
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
111111
requestId, llmRequest, pinOnRelease);
112112
}
113113

114-
std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
114+
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
115115
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
116116
{
117-
PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
117+
PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
118118
requestId, llmRequest, pinBlocks);
119119
}
120120

@@ -367,7 +367,22 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
367367
py::call_guard<py::gil_scoped_release>())
368368
.def("add_token", &BaseKVCacheManager::addToken, py::call_guard<py::gil_scoped_release>())
369369
.def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard<py::gil_scoped_release>())
370-
.def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard<py::gil_scoped_release>())
370+
.def(
371+
"remove_sequence",
372+
[](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
373+
bool pinOnRelease)
374+
{
375+
if (llmRequest != nullptr)
376+
{
377+
return self.removeSequence(requestId, *llmRequest, pinOnRelease);
378+
}
379+
else
380+
{
381+
return self.removeSequence(requestId, std::nullopt, pinOnRelease);
382+
}
383+
},
384+
py::arg("request_id"), py::arg("llm_request") = nullptr, py::arg("pin_on_release") = false,
385+
py::call_guard<py::gil_scoped_release>())
371386
.def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard<py::gil_scoped_release>())
372387
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
373388
py::call_guard<py::gil_scoped_release>())

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4050,11 +4050,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
40504050
kvCacheManager.pinBlocks(requestId);
40514051
auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
40524052
ASSERT_TRUE(lastBlockIdOpt.has_value());
4053+
auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
4054+
std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
40534055
(void) kvCacheManager.removeSequence(requestId, llmRequest);
40544056
auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
40554057
EXPECT_LT(freeAfterRemovePinned, totalBlocks);
40564058

4057-
kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
4059+
kvCacheManager.unpinBlocksById(pinnedBlockIds);
40584060
auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
40594061
EXPECT_EQ(freeAfterUnpin, totalBlocks);
40604062
}

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,8 @@ def _executor_loop_pp(self):
10711071
for req in previous_batch.scheduled_ctx_reqs:
10721072
if req.is_context_only_request and (
10731073
req.is_context_finished
1074-
or req.is_finished_due_to_length):
1074+
or req.is_finished_due_to_length
1075+
) and not req.is_finished_due_to_cancellation:
10751076
block_id = self.kv_cache_manager.store_blocks_for_reuse(
10761077
req, True)
10771078
self.ctx_in_transmission_requests[
@@ -1340,7 +1341,8 @@ def _executor_loop(self):
13401341
for req in scheduled_batch.context_requests:
13411342
if req.is_context_only_request and (
13421343
req.is_context_finished
1343-
or req.is_finished_due_to_length):
1344+
or req.is_finished_due_to_length
1345+
) and not req.is_finished_due_to_cancellation:
13441346
block_id = self.kv_cache_manager.store_blocks_for_reuse(
13451347
req, True)
13461348
self.ctx_in_transmission_requests[
@@ -1567,7 +1569,8 @@ def _executor_loop_overlap(self):
15671569
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
15681570
if req.is_context_only_request and (
15691571
req.is_context_finished
1570-
or req.is_finished_due_to_length):
1572+
or req.is_finished_due_to_length
1573+
) and not req.is_finished_due_to_cancellation:
15711574
block_id = self.kv_cache_manager.store_blocks_for_reuse(
15721575
req, True)
15731576
self.ctx_in_transmission_requests[
@@ -2076,8 +2079,9 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
20762079
if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
20772080
return []
20782081
for req in scheduled_ctx_requests:
2079-
if req.is_context_only_request and (req.is_context_finished or
2080-
req.is_finished_due_to_length):
2082+
if req.is_context_only_request and (
2083+
req.is_context_finished or req.is_finished_due_to_length
2084+
) and not req.is_finished_due_to_cancellation:
20812085
self.kv_cache_transceiver.respond_and_send_async(req)
20822086
for resource_mgr_type in (
20832087
ResourceManagerType.SEQ_SLOT_MANAGER,
@@ -2327,7 +2331,12 @@ def _do_terminate_request(self, request: LlmRequest):
23272331
self.ctx_in_transmission_requests[request.py_request_id] = (
23282332
(request, block_id, self.ctx_in_transmission_counter))
23292333

2330-
self.resource_manager.free_resources(request)
2334+
store_blocks_for_reuse = not (self.block_reuse_enabled
2335+
and not self.kv_cache_manager.is_vswa
2336+
and self.kv_cache_transceiver
2337+
and request.is_context_only_request)
2338+
self.resource_manager.free_resources(
2339+
request, store_blocks_for_reuse=store_blocks_for_reuse)
23312340

23322341
if self.gather_all_responses or self.dist.rank == 0:
23332342
self.result_wait_queues.pop(request.py_request_id, None)

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -663,8 +663,13 @@ def update_kv_cache_draft_token_location(self,
663663
None,
664664
)
665665

666-
def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
667-
return self.impl.remove_sequence(request.py_request_id, request,
666+
def free_resources(self,
667+
request: LlmRequest,
668+
pin_on_release: bool = False,
669+
store_blocks_for_reuse: bool = True):
670+
# When store_blocks_for_reuse is False, pass None to prevent block storage
671+
llm_request = request if store_blocks_for_reuse else None
672+
return self.impl.remove_sequence(request.py_request_id, llm_request,
668673
pin_on_release)
669674

670675
def store_blocks_for_reuse(self,
@@ -1407,10 +1412,17 @@ def update_resources(self,
14071412
else:
14081413
resource_manager.update_resources(scheduled_batch)
14091414

1410-
def free_resources(self, request: LlmRequest):
1411-
for _, resource_manager in reversed(self.resource_managers.items()):
1415+
def free_resources(self,
1416+
request: LlmRequest,
1417+
store_blocks_for_reuse: bool = True):
1418+
for resource_type, resource_manager in reversed(
1419+
self.resource_managers.items()):
14121420
if hasattr(resource_manager, "free_resources"):
1413-
resource_manager.free_resources(request)
1421+
if resource_type == ResourceManagerType.KV_CACHE_MANAGER:
1422+
resource_manager.free_resources(
1423+
request, store_blocks_for_reuse=store_blocks_for_reuse)
1424+
else:
1425+
resource_manager.free_resources(request)
14141426

14151427
def reorder_pipeline(self,
14161428
resource_manager_list: list[ResourceManagerType]):

0 commit comments

Comments (0)