Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ class CacheTransceiverFactory
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
};

struct RequestStatuses
{
/// Requests that have completed their transfer successfully.
std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
/// Requests that have encountered an error during their transfer.
std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
};

class BaseCacheTransceiver
{
public:
Expand All @@ -202,7 +210,10 @@ class BaseCacheTransceiver
virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;

virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
= 0;

virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;

Expand Down Expand Up @@ -243,7 +254,8 @@ class CacheTransceiver : public BaseCacheTransceiver
void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;

Expand Down
15 changes: 13 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
}
}

void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
RequestStatuses CacheTransceiver::checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum, bool markComplete)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> senderFutureTimeoutMs = std::nullopt;
Expand Down Expand Up @@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
toCompleteIdSet.insert(request->mRequestId);
}

RequestStatuses requestsStatus{};

// Complete all the requests in toCompleteIdSet
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
{
Expand All @@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
{
future.get();
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
requestsStatus.completedRequestIds.insert(request->mRequestId);
if (markComplete)
{
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
}
it = mSenderFutures.erase(it);
}
else if (status == std::future_status::timeout)
Expand All @@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
"Future returned unexpected status for request %ld. Marking as error", request->mRequestId);

request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
Expand All @@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
Expand All @@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
++it;
}
}

return requestsStatus;
}

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
{
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(true);
mCacheTransceiver->checkContextTransferStatus(1, true);
TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
}
if (mAsyncSendWaitThread)
Expand Down Expand Up @@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync()
}
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(0);
mCacheTransceiver->checkContextTransferStatus(0, true);
}
++mIterCounter;

Expand Down Expand Up @@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
mIterCounter);
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(1);
mCacheTransceiver->checkContextTransferStatus(1, true);
// will free kvCache in next iteration.
}
}
Expand Down
24 changes: 20 additions & 4 deletions cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
}

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
tb::RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum);
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete);
}

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
Expand All @@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
nb::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}

auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return nb::make_tuple(completedRequestIds, errorRequestIds);
},
nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
Expand Down
27 changes: 23 additions & 4 deletions cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
}

void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
using RequestStatuses = tb::RequestStatuses;

RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum);
PYBIND11_OVERLOAD_PURE(
RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
}

void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
Expand All @@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
py::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}

auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return py::make_tuple(completedRequestIds, errorRequestIds);
},
py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
Expand Down
Loading