Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1465,16 +1465,19 @@ class CacheTransceiverConfig
NIXL = 3
};
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
std::optional<int> kvTransferSenderFutureTimeoutMs = std::nullopt);

bool operator==(CacheTransceiverConfig const& other) const;
void setBackendType(std::optional<BackendType> backendType);
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);
void setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs);

[[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
[[nodiscard]] std::optional<BackendType> getBackendType() const;
[[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
[[nodiscard]] std::optional<int> getKvTransferSenderFutureTimeoutMs() const;

private:
std::optional<BackendType> mBackendType;
Expand All @@ -1483,6 +1486,9 @@ class CacheTransceiverConfig
/// transfer may be degraded.
std::optional<size_t> mMaxTokensInBuffer;
std::optional<int> mKvTransferTimeoutMs;
// @brief Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This
// allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms
std::optional<int> mKvTransferSenderFutureTimeoutMs;
};

/// @brief Configuration class for the model executor
Expand Down
33 changes: 30 additions & 3 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,13 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> senderFutureTimeoutMs = std::nullopt;
// If blockAll is true, we want to block and not use a timeout
if (!blockAll && mCacheTransceiverConfig.has_value())
{
senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
}

auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
for (auto&& [request, future] : mSenderFutures)
Expand Down Expand Up @@ -476,16 +483,36 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
{
try
{
future.get();
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
// Wait for up to a specified timeout
auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
{
future.get();
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
it = mSenderFutures.erase(it);
}
else if (status == std::future_status::timeout)
{
TLLM_LOG_WARNING("Timed out waiting for context transfer for request %ld after %d milliseconds.",
request->mRequestId, senderFutureTimeoutMs.value());
++it;
}
else
{
TLLM_LOG_ERROR(
"Future returned unexpected status for request %ld. Marking as error", request->mRequestId);

request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mSenderFutures.erase(it);
}
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mSenderFutures.erase(it);
}
it = mSenderFutures.erase(it);
}
else
{
Expand Down
19 changes: 17 additions & 2 deletions cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
namespace tensorrt_llm::executor
{

CacheTransceiverConfig::CacheTransceiverConfig(
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
// Constructs a CacheTransceiverConfig. Every parameter is optional and defaults to
// "unset" (std::nullopt). The constructor performs no validation itself; note that
// setKvTransferSenderFutureTimeoutMs rejects non-positive values, so values passed
// here bypass that check — presumably callers pass already-validated values (TODO confirm).
CacheTransceiverConfig::CacheTransceiverConfig(std::optional<BackendType> backendType,
    std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs,
    std::optional<int> kvTransferSenderFutureTimeoutMs)
    : mBackendType(backendType)
    , mMaxTokensInBuffer(maxNumTokens)
    , mKvTransferTimeoutMs(kvTransferTimeoutMs)
    , mKvTransferSenderFutureTimeoutMs(kvTransferSenderFutureTimeoutMs)
{
}

Expand Down Expand Up @@ -54,6 +56,15 @@ void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransfe
mKvTransferTimeoutMs = kvTransferTimeoutMs;
}

// @brief Sets the timeout (milliseconds) used when waiting on the sender future.
// Passing std::nullopt clears the timeout (wait indefinitely); a set value must be
// strictly positive, otherwise TLLM_THROW raises an error.
void CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs)
{
    // A present-but-non-positive timeout is meaningless; reject it early.
    if (kvTransferSenderFutureTimeoutMs && *kvTransferSenderFutureTimeoutMs <= 0)
    {
        TLLM_THROW("kvTransferSenderFutureTimeoutMs must be positive");
    }
    mKvTransferSenderFutureTimeoutMs = kvTransferSenderFutureTimeoutMs;
}

std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
{
return mBackendType;
Expand All @@ -69,4 +80,8 @@ std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
return mKvTransferTimeoutMs;
}

// @brief Returns the sender-future wait timeout in milliseconds, or std::nullopt if unset.
std::optional<int> CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs() const
{
    return mKvTransferSenderFutureTimeoutMs;
}
} // namespace tensorrt_llm::executor
8 changes: 7 additions & 1 deletion cpp/tensorrt_llm/executor/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1290,20 +1290,26 @@ CacheTransceiverConfig Serialization::deserializeCacheTransceiverConfig(std::ist
{
auto backendType = su::deserialize<std::optional<CacheTransceiverConfig::BackendType>>(is);
auto maxTokensInBuffer = su::deserialize<std::optional<size_t>>(is);
return CacheTransceiverConfig{backendType, maxTokensInBuffer};
auto kvTransferTimeoutMs = su::deserialize<std::optional<int>>(is);
auto kvTransferSenderFutureTimeoutMs = su::deserialize<std::optional<int>>(is);
return CacheTransceiverConfig{backendType, maxTokensInBuffer, kvTransferTimeoutMs, kvTransferSenderFutureTimeoutMs};
}

// @brief Writes a CacheTransceiverConfig to the stream.
// NOTE: the field order here defines the wire format and must stay in lockstep with
// deserializeCacheTransceiverConfig and serializedSize — do not reorder.
void Serialization::serialize(CacheTransceiverConfig const& cacheTransceiverConfig, std::ostream& os)
{
    su::serialize(cacheTransceiverConfig.getBackendType(), os);
    su::serialize(cacheTransceiverConfig.getMaxTokensInBuffer(), os);
    su::serialize(cacheTransceiverConfig.getKvTransferTimeoutMs(), os);
    su::serialize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(), os);
}

// @brief Returns the number of bytes serialize() will write for this config.
// Must account for exactly the same fields, in the same order, as serialize().
size_t Serialization::serializedSize(CacheTransceiverConfig const& cacheTransceiverConfig)
{
    size_t totalSize = 0;
    totalSize += su::serializedSize(cacheTransceiverConfig.getBackendType());
    totalSize += su::serializedSize(cacheTransceiverConfig.getMaxTokensInBuffer());
    totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferTimeoutMs());
    totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs());
    return totalSize;
}

Expand Down
8 changes: 6 additions & 2 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,19 @@ void initConfigBindings(nb::module_& m)

nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(),
std::optional<int>, std::optional<int>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt,
nb::arg("kv_transfer_timeout_ms") = std::nullopt)
nb::arg("kv_transfer_timeout_ms") = std::nullopt,
nb::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
.def_prop_rw(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def_prop_rw("kv_transfer_sender_future_timeout_ms",
&tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
.def("__getstate__", cacheTransceiverConfigGetstate)
.def("__setstate__", cacheTransceiverConfigSetstate);

Expand Down
8 changes: 6 additions & 2 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,19 @@ void initConfigBindings(pybind11::module_& m)

py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(),
std::optional<int>, std::optional<int>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
py::arg("kv_transfer_timeout_ms") = std::nullopt)
py::arg("kv_transfer_timeout_ms") = std::nullopt,
py::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
.def_property(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def_property("kv_transfer_sender_future_timeout_ms",
&tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

auto executorConfigGetState = [](py::object const& self)
Expand Down
7 changes: 5 additions & 2 deletions cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ TEST(SerializeUtilsTest, ExecutorConfig)
texec::GuidedDecodingConfig(
texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, std::initializer_list<std::string>{"eos"}),
std::vector{tensorrt_llm::executor::AdditionalModelOutput{"output_name"}},
texec::CacheTransceiverConfig(std::nullopt, 1024), true, true, true);
texec::CacheTransceiverConfig(std::nullopt, 1024, 100, 1000), true, true, true);
auto executorConfig2 = serializeDeserialize(executorConfig);

EXPECT_EQ(executorConfig.getMaxBeamWidth(), executorConfig2.getMaxBeamWidth());
Expand Down Expand Up @@ -1028,10 +1028,13 @@ TEST(SerializeUtilsTest, MethodReturnType)
TEST(SerializeUtilsTest, CacheTransceiverConfig)
{
texec::CacheTransceiverConfig cacheTransceiverConfig(
tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024);
tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024, 100, 1000);
auto cacheTransceiverConfig2 = serializeDeserialize(cacheTransceiverConfig);
EXPECT_EQ(cacheTransceiverConfig.getBackendType(), cacheTransceiverConfig2.getBackendType());
EXPECT_EQ(cacheTransceiverConfig.getMaxTokensInBuffer(), cacheTransceiverConfig2.getMaxTokensInBuffer());
EXPECT_EQ(cacheTransceiverConfig.getKvTransferTimeoutMs(), cacheTransceiverConfig2.getKvTransferTimeoutMs());
EXPECT_EQ(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(),
cacheTransceiverConfig2.getKvTransferSenderFutureTimeoutMs());
}

TEST(SerializeUtilsTest, BlockKeyBasic)
Expand Down
2 changes: 2 additions & 0 deletions examples/disaggregated/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ cache_transceiver_config:
# KV cache transfer timeout in milliseconds
# For requests, if they do not send/receive the KV cache in time they are cancelled and cleaned up
kv_transfer_timeout_ms: <int>
# Timeout in milliseconds to wait for the sender future to become ready when the scheduled batch size is 0.
# This allows the request to eventually be cancelled by the user or via kv_transfer_timeout_ms.
kv_transfer_sender_future_timeout_ms: <int>
```
The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)

self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
self.kv_transfer_sender_future_timeout_ms = cache_transceiver_config.kv_transfer_sender_future_timeout_ms
self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config,
Expand Down
11 changes: 10 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,11 +1563,20 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
"Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
)

kv_transfer_sender_future_timeout_ms: Optional[int] = Field(
default=1000,
gt=0,
description=
"Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms"
)

def _to_pybind(self):
    # Convert this Pydantic-level config into its C++ pybind counterpart.
    # Fields map one-to-one; the backend string is translated to the pybind
    # enum via from_string.
    return _CacheTransceiverConfig(
        backend=_CacheTransceiverBackendType.from_string(self.backend),
        max_tokens_in_buffer=self.max_tokens_in_buffer,
        kv_transfer_timeout_ms=self.kv_transfer_timeout_ms,
        kv_transfer_sender_future_timeout_ms=self.
        kv_transfer_sender_future_timeout_ms)


@dataclass
Expand Down
79 changes: 79 additions & 0 deletions tests/unittest/llmapi/test_llm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,3 +1032,82 @@ def test_llm_context_only_timed_out():
final_used_num_blocks = results[0]["kvCacheStats"]["usedNumBlocks"]

assert final_used_num_blocks == 0


# This test is to verify that when the KV cache is exhausted and scheduled batch size is 0, the context only request will be aborted due to timeout.


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray
@pytest.mark.parametrize("sender_future_timeout_ms", [100, 1000])
def test_llm_context_only_timed_out_kv_cache_exhausted(
        sender_future_timeout_ms):
    """Verify that context-only requests are eventually timed out and their
    KV-cache blocks released when the KV cache is exhausted (scheduled batch
    size is 0), for both a short and a longer sender-future timeout.
    """
    tp_size = 1
    use_overlap = False
    enable_iter_req_stats = False

    llm_args_extra = {}

    llm_args_extra.update(
        dict(enable_iter_perf_stats=True,
             enable_iter_req_stats=enable_iter_req_stats,
             disable_overlap_scheduler=not use_overlap))

    # Deliberately tiny KV cache (10% of free GPU memory, 1000 tokens, no block
    # reuse) so the long prompts below exhaust it and stall scheduling.
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.1,
                                    max_tokens=1000,
                                    enable_block_reuse=False)
    llm = LLM(
        model=llama_model_path,
        kv_cache_config=kv_cache_config,
        tensor_parallel_size=tp_size,
        cache_transceiver_config=CacheTransceiverConfig(
            backend="DEFAULT",
            kv_transfer_timeout_ms=1000,
            kv_transfer_sender_future_timeout_ms=sender_future_timeout_ms),
        **llm_args_extra)

    max_tokens = 1
    sampling_params = SamplingParams(max_tokens=max_tokens)

    disaggregated_params = DisaggregatedParams(request_type="context_only")

    # prompts0: short regular prompt; prompts1: long prompt repeated to fill the cache.
    prompts0 = [
        "What is your name?",
    ]
    prompts1 = [
        "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
        * 10
    ]

    # Send context-only request
    for output in llm.generate(prompts1 * 10,
                               sampling_params=sampling_params,
                               disaggregated_params=disaggregated_params):
        print(output)

    # NOTE(review): this loop polls stats a fixed max_retries times with no break
    # or sleep — presumably best-effort collection of recent iterations; confirm
    # this is intended rather than a retry-until-available loop.
    max_retries = 10
    all_results = []
    for _ in range(max_retries):
        results = llm.get_stats(2)
        all_results.extend(results)

    assert len(all_results) > 0

    # Block usage right after the context-only requests; printed for debugging.
    context_only_used_num_blocks = all_results[-1]["kvCacheStats"][
        "usedNumBlocks"]
    print(f"Context only used num blocks: {context_only_used_num_blocks}")

    # Sleep 5 seconds to allow context only request to time out
    time.sleep(5)

    # Send regular request
    for output in llm.generate(prompts0, sampling_params=sampling_params):
        print(output)

    # Get number of allocated blocks
    results = llm.get_stats(2)
    assert len(results) == 1
    final_used_num_blocks = results[0]["kvCacheStats"]["usedNumBlocks"]

    # All blocks must have been freed once the timed-out context-only requests
    # were cleaned up and the short regular request completed.
    assert final_used_num_blocks == 0