Skip to content

Commit c11ce69

Browse files
committed
Add testing and address review comments
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 3182807 commit c11ce69

File tree

8 files changed

+236
-57
lines changed

8 files changed

+236
-57
lines changed

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,10 +1797,8 @@ void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const
17971797

17981798
for (auto const& blockId : blockIds)
17991799
{
1800-
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
1801-
{
1802-
continue;
1803-
}
1800+
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
1801+
"Block id %d is out of range", blockId);
18041802
auto block = mAllBlocksById[blockId];
18051803
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
18061804
{

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -363,22 +363,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
363363
nb::call_guard<nb::gil_scoped_release>())
364364
.def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
365365
.def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
366-
.def(
367-
"remove_sequence",
368-
[](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
369-
bool pinOnRelease)
370-
{
371-
if (llmRequest != nullptr)
372-
{
373-
return self.removeSequence(requestId, *llmRequest, pinOnRelease);
374-
}
375-
else
376-
{
377-
return self.removeSequence(requestId, std::nullopt, pinOnRelease);
378-
}
379-
},
380-
nb::arg("request_id"), nb::arg("llm_request") = nullptr, nb::arg("pin_on_release") = false,
381-
nb::call_guard<nb::gil_scoped_release>())
366+
.def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
382367
.def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard<nb::gil_scoped_release>())
383368
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
384369
nb::call_guard<nb::gil_scoped_release>())

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -367,22 +367,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
367367
py::call_guard<py::gil_scoped_release>())
368368
.def("add_token", &BaseKVCacheManager::addToken, py::call_guard<py::gil_scoped_release>())
369369
.def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard<py::gil_scoped_release>())
370-
.def(
371-
"remove_sequence",
372-
[](tbk::BaseKVCacheManager& self, tb::LlmRequest::RequestIdType requestId, tb::LlmRequest const* llmRequest,
373-
bool pinOnRelease)
374-
{
375-
if (llmRequest != nullptr)
376-
{
377-
return self.removeSequence(requestId, *llmRequest, pinOnRelease);
378-
}
379-
else
380-
{
381-
return self.removeSequence(requestId, std::nullopt, pinOnRelease);
382-
}
383-
},
384-
py::arg("request_id"), py::arg("llm_request") = nullptr, py::arg("pin_on_release") = false,
385-
py::call_guard<py::gil_scoped_release>())
370+
.def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard<py::gil_scoped_release>())
386371
.def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard<py::gil_scoped_release>())
387372
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
388373
py::call_guard<py::gil_scoped_release>())

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,12 +2445,7 @@ def _do_terminate_request(self, request: LlmRequest):
24452445
self.ctx_in_transmission_requests[request.py_request_id] = (
24462446
(request, block_id, self.ctx_in_transmission_counter))
24472447

2448-
store_blocks_for_reuse = not (self.block_reuse_enabled
2449-
and not self.kv_cache_manager.is_vswa
2450-
and self.kv_cache_transceiver
2451-
and request.is_context_only_request)
2452-
self.resource_manager.free_resources(
2453-
request, store_blocks_for_reuse=store_blocks_for_reuse)
2448+
self.resource_manager.free_resources(request)
24542449

24552450
if self.gather_all_responses or self.dist.rank == 0:
24562451
self.result_wait_queues.pop(request.py_request_id, None)

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -685,13 +685,8 @@ def update_kv_cache_draft_token_location(self,
685685
None,
686686
)
687687

688-
def free_resources(self,
689-
request: LlmRequest,
690-
pin_on_release: bool = False,
691-
store_blocks_for_reuse: bool = True):
692-
# When store_blocks_for_reuse is False, pass None to prevent block storage
693-
llm_request = request if store_blocks_for_reuse else None
694-
return self.impl.remove_sequence(request.py_request_id, llm_request,
688+
def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
689+
return self.impl.remove_sequence(request.py_request_id, request,
695690
pin_on_release)
696691

697692
def store_blocks_for_reuse(self,
@@ -1435,17 +1430,11 @@ def update_resources(self,
14351430
else:
14361431
resource_manager.update_resources(scheduled_batch)
14371432

1438-
def free_resources(self,
1439-
request: LlmRequest,
1440-
store_blocks_for_reuse: bool = True):
1433+
def free_resources(self, request: LlmRequest):
14411434
for resource_type, resource_manager in reversed(
14421435
self.resource_managers.items()):
14431436
if hasattr(resource_manager, "free_resources"):
1444-
if resource_type == ResourceManagerType.KV_CACHE_MANAGER:
1445-
resource_manager.free_resources(
1446-
request, store_blocks_for_reuse=store_blocks_for_reuse)
1447-
else:
1448-
resource_manager.free_resources(request)
1437+
resource_manager.free_resources(request)
14491438

14501439
def reorder_pipeline(self,
14511440
resource_manager_list: list[ResourceManagerType]):
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
hostname: localhost
2+
port: 8000
3+
model: DeepSeek-V3-Lite/bf16
4+
backend: "pytorch"
5+
enable_autotuner: False
6+
context_servers:
7+
disable_overlap_scheduler: True
8+
num_instances: 1
9+
tensor_parallel_size: 1
10+
pipeline_parallel_size: 1
11+
max_num_tokens: 16384
12+
max_seq_len: 32768
13+
enable_chunked_prefill: True
14+
kv_cache_config:
15+
enable_block_reuse: True
16+
enable_partial_reuse: True
17+
free_gpu_memory_fraction: 0.3
18+
cache_transceiver_config:
19+
backend: "DEFAULT"
20+
max_tokens_in_buffer: 32768
21+
cuda_graph_config:
22+
enable_padding: True
23+
max_batch_size: 1
24+
urls:
25+
- "localhost:8001"
26+
generation_servers:
27+
num_instances: 1
28+
tensor_parallel_size: 1
29+
pipeline_parallel_size: 1
30+
max_num_tokens: 2048
31+
max_seq_len: 32768
32+
enable_chunked_prefill: True
33+
kv_cache_config:
34+
enable_block_reuse: True
35+
enable_partial_reuse: True
36+
free_gpu_memory_fraction: 0.85
37+
cache_transceiver_config:
38+
backend: "DEFAULT"
39+
max_tokens_in_buffer: 32768
40+
cuda_graph_config:
41+
enable_padding: True
42+
max_batch_size: 64
43+
urls:
44+
- "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def get_test_config(test_desc, example_dir, test_root):
200200
"gpt_oss_120b_stress":
201201
(4,
202202
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
203+
"cancel_stress_test":
204+
(2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
203205
}
204206

205207
if test_desc not in config_map:
@@ -2098,3 +2100,183 @@ def test_disaggregated_stress_test(disaggregated_test_root,
20982100
threshold=test_config.accuracy_threshold,
20992101
env=llm_venv._new_env,
21002102
cwd=llm_venv.get_working_directory())
2103+
2104+
2105+
def run_cancel_stress_test(server_url: str,
2106+
num_bursts: int = 5,
2107+
requests_per_burst: int = 32,
2108+
prompt_len_range: tuple = (2000, 8000),
2109+
cancel_after_range: tuple = (0.01, 0.1)):
2110+
"""
2111+
Stress test that sends requests with large contexts and cancels them
2112+
during prefill to test resource cleanup under cancellation.
2113+
2114+
Args:
2115+
server_url: The server URL (e.g., "http://localhost:8000")
2116+
num_bursts: Number of request bursts to send
2117+
requests_per_burst: Number of concurrent requests per burst
2118+
prompt_len_range: (min, max) prompt length in tokens
2119+
cancel_after_range: (min, max) seconds to wait before cancelling
2120+
"""
2121+
import asyncio
2122+
import random
2123+
import time
2124+
2125+
import aiohttp
2126+
2127+
async def spam_and_cancel(session, req_id, url, prompt_len_range,
2128+
cancel_after_range):
2129+
"""Send a request and cancel it during prefill."""
2130+
prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
2131+
prompt = "test " * (prompt_len // 5)
2132+
2133+
payload = {
2134+
"model": "test-model",
2135+
"prompt": prompt,
2136+
"max_tokens": 10,
2137+
"stream": True
2138+
}
2139+
2140+
try:
2141+
cancel_after = random.uniform(cancel_after_range[0],
2142+
cancel_after_range[1])
2143+
start = time.time()
2144+
async with session.post(
2145+
f"{url}/v1/completions",
2146+
json=payload,
2147+
timeout=aiohttp.ClientTimeout(total=60)) as resp:
2148+
async for line in resp.content:
2149+
if time.time() - start > cancel_after:
2150+
# Force disconnect during prefill
2151+
break
2152+
except Exception:
2153+
pass # Connection abort is expected
2154+
2155+
async def run_bursts():
2156+
async with aiohttp.ClientSession() as session:
2157+
for burst_idx in range(num_bursts):
2158+
tasks = [
2159+
spam_and_cancel(session, i, server_url, prompt_len_range,
2160+
cancel_after_range)
2161+
for i in range(requests_per_burst)
2162+
]
2163+
await asyncio.gather(*tasks)
2164+
logger.info(
2165+
f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
2166+
)
2167+
await asyncio.sleep(0.05)
2168+
2169+
asyncio.run(run_bursts())
2170+
2171+
2172+
def run_disaggregated_cancel_test(example_dir,
2173+
test_desc,
2174+
env=None,
2175+
cwd=None,
2176+
num_bursts=64,
2177+
requests_per_burst=64):
2178+
"""Run disaggregated test with request cancellation stress test."""
2179+
cleanup_output_files()
2180+
run_env = env.copy()
2181+
run_env["UCX_TLS"] = "^ib"
2182+
2183+
num_ranks, config_file = get_test_config(test_desc, example_dir,
2184+
os.path.dirname(__file__))
2185+
2186+
workers_cmd = [
2187+
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
2188+
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
2189+
config_file
2190+
]
2191+
2192+
server_start_timeout = 1200
2193+
server_cmd = [
2194+
'trtllm-serve', 'disaggregated', '--server_start_timeout',
2195+
str(server_start_timeout), '-c', config_file
2196+
]
2197+
server_host, server_port = get_disagg_server_url_from_cfg(config_file)
2198+
server_url = f"http://{server_host}:{server_port}"
2199+
2200+
try:
2201+
with (open('output_workers.log', 'w') as output_workers,
2202+
popen(workers_cmd,
2203+
stdout=output_workers,
2204+
stderr=subprocess.STDOUT,
2205+
env=run_env,
2206+
cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
2207+
output_disagg,
2208+
popen(server_cmd,
2209+
stdout=output_disagg,
2210+
stderr=subprocess.STDOUT,
2211+
env=run_env,
2212+
cwd=cwd) as server_proc):
2213+
2214+
# Wait for server to be ready
2215+
if not wait_for_server(server_host,
2216+
server_port,
2217+
timeout_seconds=server_start_timeout):
2218+
raise RuntimeError(
2219+
f"Disaggregated server did not become ready within {server_start_timeout} seconds"
2220+
)
2221+
2222+
# Run the cancel stress test
2223+
run_cancel_stress_test(server_url,
2224+
num_bursts=num_bursts,
2225+
requests_per_burst=requests_per_burst)
2226+
2227+
# Verify server is still healthy after stress test by sending a normal request
2228+
client_dir = f"{example_dir}/clients"
2229+
client_cmd = [
2230+
'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
2231+
'-p', f'{client_dir}/prompts.json', '--ignore-eos',
2232+
'--server-start-timeout',
2233+
str(server_start_timeout)
2234+
]
2235+
check_call(client_cmd,
2236+
env=env,
2237+
poll_procs=[workers_proc, server_proc])
2238+
2239+
except Exception:
2240+
logger.error("-------- Workers output --------")
2241+
with open('output_workers.log', 'r') as f:
2242+
logger.error(f.read())
2243+
2244+
logger.error("-------- Disagg server output --------")
2245+
with open('output_disagg.log', 'r') as f:
2246+
logger.error(f.read())
2247+
raise
2248+
finally:
2249+
if 'server_proc' in locals() and 'workers_proc' in locals():
2250+
server_proc.terminate()
2251+
workers_proc.terminate()
2252+
server_proc.wait()
2253+
workers_proc.wait()
2254+
2255+
2256+
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
2257+
indirect=True)
2258+
def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
2259+
disaggregated_example_root,
2260+
llm_venv,
2261+
deepseek_v3_model_root):
2262+
"""
2263+
Test that the disaggregated server handles request cancellations gracefully.
2264+
2265+
This test sends bursts of requests with large contexts and cancels them
2266+
during prefill to stress test resource cleanup.
2267+
"""
2268+
src_dst_dict = {
2269+
deepseek_v3_model_root:
2270+
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
2271+
}
2272+
for src, dst in src_dst_dict.items():
2273+
if not os.path.islink(dst):
2274+
os.makedirs(os.path.dirname(dst), exist_ok=True)
2275+
os.symlink(src, dst, target_is_directory=True)
2276+
2277+
run_disaggregated_cancel_test(disaggregated_example_root,
2278+
"cancel_stress_test",
2279+
env=llm_venv._new_env,
2280+
cwd=llm_venv.get_working_directory(),
2281+
num_bursts=5,
2282+
requests_per_burst=32)

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ l0_dgx_h100:
4343
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
4444
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
4545
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
46+
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
4647
# ------------- AutoDeploy tests ---------------
4748
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
4849
# llmapi

0 commit comments

Comments
 (0)