Skip to content
Open
16 changes: 14 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ class CacheTransceiverFactory
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
};

struct RequestStatuses
{
/// Requests that have completed their transfer successfully.
std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
/// Requests that have encountered an error during their transfer.
std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
};

class BaseCacheTransceiver
{
public:
Expand All @@ -202,7 +210,10 @@ class BaseCacheTransceiver
virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;

virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
= 0;

virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;

Expand Down Expand Up @@ -243,7 +254,8 @@ class CacheTransceiver : public BaseCacheTransceiver
void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;

Expand Down
15 changes: 13 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
}
}

void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
RequestStatuses CacheTransceiver::checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum, bool markComplete)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> senderFutureTimeoutMs = std::nullopt;
Expand Down Expand Up @@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
toCompleteIdSet.insert(request->mRequestId);
}

RequestStatuses requestsStatus{};

// Complete all the requests in toCompleteIdSet
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
{
Expand All @@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
{
future.get();
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
requestsStatus.completedRequestIds.insert(request->mRequestId);
if (markComplete)
{
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
}
it = mSenderFutures.erase(it);
}
else if (status == std::future_status::timeout)
Expand All @@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
"Future returned unexpected status for request %ld. Marking as error", request->mRequestId);

request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
Expand All @@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
Expand All @@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
++it;
}
}

return requestsStatus;
}

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
{
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(true);
mCacheTransceiver->checkContextTransferStatus(1, true);
TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
}
if (mAsyncSendWaitThread)
Expand Down Expand Up @@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync()
}
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(0);
mCacheTransceiver->checkContextTransferStatus(0, true);
}
++mIterCounter;

Expand Down Expand Up @@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
mIterCounter);
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(1);
mCacheTransceiver->checkContextTransferStatus(1, true);
// will free kvCache in next iteration.
}
}
Expand Down
24 changes: 20 additions & 4 deletions cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
}

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
tb::RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum);
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete);
}

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
Expand All @@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
nb::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}

auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return nb::make_tuple(completedRequestIds, errorRequestIds);
},
nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
Expand Down
27 changes: 23 additions & 4 deletions cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
}

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
using RequestStatuses = tb::RequestStatuses;

RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum);
PYBIND11_OVERLOAD_PURE(
RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
}

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
Expand All @@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
py::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}

auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return py::make_tuple(completedRequestIds, errorRequestIds);
},
py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
Expand Down
Loading