Skip to content

Commit 2939fb5

Browse files
committed
Add testing and address review comments
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 2d8a6db commit 2939fb5

File tree

4 files changed

+230
-4
lines changed

4 files changed

+230
-4
lines changed

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,10 +1797,8 @@ void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const
17971797

17981798
for (auto const& blockId : blockIds)
17991799
{
1800-
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
1801-
{
1802-
continue;
1803-
}
1800+
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
1801+
"Block id %d is out of range", blockId);
18041802
auto block = mAllBlocksById[blockId];
18051803
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
18061804
{
tests/integration/defs/disaggregated/test_configs/disagg_config_cancel_stress_test.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Disaggregated-serving config for the request-cancellation stress test
# (used by test_disaggregated.py, test key "cancel_stress_test").
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/bf16
backend: "pytorch"
enable_autotuner: False
# Prefill (context) side: 1 rank on localhost:8001.
context_servers:
  disable_overlap_scheduler: True
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  max_num_tokens: 16384
  max_seq_len: 32768
  enable_chunked_prefill: True
  kv_cache_config:
    enable_block_reuse: True
    enable_partial_reuse: True
    # Context side keeps a small KV pool; generation side gets the bulk.
    free_gpu_memory_fraction: 0.3
  cache_transceiver_config:
    backend: "DEFAULT"
    max_tokens_in_buffer: 32768
  cuda_graph_config:
    enable_padding: True
    max_batch_size: 1
  urls:
    - "localhost:8001"
# Decode (generation) side: 1 rank on localhost:8002.
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  max_num_tokens: 2048
  max_seq_len: 32768
  enable_chunked_prefill: True
  kv_cache_config:
    enable_block_reuse: True
    enable_partial_reuse: True
    free_gpu_memory_fraction: 0.85
  cache_transceiver_config:
    backend: "DEFAULT"
    max_tokens_in_buffer: 32768
  cuda_graph_config:
    enable_padding: True
    max_batch_size: 64
  urls:
    - "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def get_test_config(test_desc, example_dir, test_root):
200200
"gpt_oss_120b_stress":
201201
(4,
202202
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
203+
"cancel_stress_test":
204+
(2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
203205
}
204206

205207
if test_desc not in config_map:
@@ -2098,3 +2100,184 @@ def test_disaggregated_stress_test(disaggregated_test_root,
20982100
threshold=test_config.accuracy_threshold,
20992101
env=llm_venv._new_env,
21002102
cwd=llm_venv.get_working_directory())
2103+
2104+
2105+
def run_cancel_stress_test(server_url: str,
                           num_bursts: int = 5,
                           requests_per_burst: int = 32,
                           prompt_len_range: tuple = (2000, 8000),
                           cancel_after_range: tuple = (0.01, 0.1)):
    """
    Stress-test request cancellation: fire bursts of large-context streaming
    completion requests at the server and drop each connection shortly after
    it starts, so the server must clean up work aborted mid-prefill.

    Args:
        server_url: The server URL (e.g., "http://localhost:8000")
        num_bursts: Number of request bursts to send
        requests_per_burst: Number of concurrent requests per burst
        prompt_len_range: (min, max) prompt length in tokens
        cancel_after_range: (min, max) seconds to wait before cancelling
    """
    import asyncio
    import random
    import time

    import aiohttp

    async def _abort_one(session, req_id, url, prompt_len_range,
                         cancel_after_range):
        """Issue a single streaming request and abort it soon after start."""
        token_count = random.randint(prompt_len_range[0], prompt_len_range[1])
        body = {
            "model": "test-model",
            "prompt": "test " * (token_count // 5),
            "max_tokens": 10,
            "stream": True
        }
        try:
            deadline = random.uniform(cancel_after_range[0],
                                      cancel_after_range[1])
            started = time.time()
            async with session.post(
                    f"{url}/v1/completions",
                    json=body,
                    timeout=aiohttp.ClientTimeout(total=60)) as resp:
                async for _ in resp.content:
                    if time.time() - started > deadline:
                        # Force disconnect during prefill
                        break
        except Exception:
            pass  # Connection abort is expected

    async def _drive():
        # One shared session; bursts separated by a short pause so the
        # server sees repeated waves of simultaneous aborts.
        async with aiohttp.ClientSession() as session:
            for burst_idx in range(num_bursts):
                await asyncio.gather(*(_abort_one(session, i, server_url,
                                                  prompt_len_range,
                                                  cancel_after_range)
                                       for i in range(requests_per_burst)))
                logger.info(
                    f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
                )
                await asyncio.sleep(0.05)

    asyncio.run(_drive())
2170+
2171+
2172+
def run_disaggregated_cancel_test(example_dir,
                                  test_desc,
                                  env=None,
                                  cwd=None,
                                  num_bursts=64,
                                  requests_per_burst=64):
    """Run a disaggregated deployment and stress it with cancelled requests.

    Launches the MPI workers and the disaggregated server for the config
    named by ``test_desc``, runs the cancellation stress test against it,
    then verifies the server is still healthy by issuing one normal client
    request. Worker and server logs are dumped on failure.

    Args:
        example_dir: Path to the disaggregated examples directory.
        test_desc: Key into the config map consumed by ``get_test_config``.
        env: Environment mapping for the launched processes; copied so the
            caller's mapping is never mutated. Defaults to ``os.environ``.
        cwd: Working directory for the launched processes.
        num_bursts: Number of cancellation bursts to send.
        requests_per_burst: Concurrent requests per burst.
    """
    cleanup_output_files()
    # Copy before mutating so UCX_TLS doesn't leak into the caller's mapping.
    # Fix: `env.copy()` raised AttributeError when the default env=None was
    # used; fall back to the current process environment in that case.
    run_env = dict(os.environ if env is None else env)
    run_env["UCX_TLS"] = "^ib"

    num_ranks, config_file = get_test_config(test_desc, example_dir,
                                             os.path.dirname(__file__))

    workers_cmd = [
        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
        config_file
    ]

    server_start_timeout = 1200
    server_cmd = [
        'trtllm-serve', 'disaggregated', '--server_start_timeout',
        str(server_start_timeout), '-c', config_file
    ]
    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
    server_url = f"http://{server_host}:{server_port}"

    try:
        with (open('output_workers.log', 'w') as output_workers,
              popen(workers_cmd,
                    stdout=output_workers,
                    stderr=subprocess.STDOUT,
                    env=run_env,
                    cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
              output_disagg,
              popen(server_cmd,
                    stdout=output_disagg,
                    stderr=subprocess.STDOUT,
                    env=run_env,
                    cwd=cwd) as server_proc):

            # Wait for server to be ready
            if not wait_for_server(server_host,
                                   server_port,
                                   timeout_seconds=server_start_timeout):
                raise RuntimeError(
                    f"Disaggregated server did not become ready within {server_start_timeout} seconds"
                )

            # Run the cancel stress test
            run_cancel_stress_test(server_url,
                                   num_bursts=num_bursts,
                                   requests_per_burst=requests_per_burst)

            # Verify server is still healthy after stress test by sending a normal request
            client_dir = f"{example_dir}/clients"
            client_cmd = [
                'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
                '-p', f'{client_dir}/prompts.json', '--ignore-eos',
                '--server-start-timeout',
                str(server_start_timeout)
            ]
            check_call(client_cmd,
                       env=env,
                       poll_procs=[workers_proc, server_proc])

    except Exception:
        # Surface both process logs so CI failures are diagnosable.
        logger.error("-------- Workers output --------")
        with open('output_workers.log', 'r') as f:
            logger.error(f.read())

        logger.error("-------- Disagg server output --------")
        with open('output_disagg.log', 'r') as f:
            logger.error(f.read())
        raise
    finally:
        # Procs exist only if both popen() calls succeeded; terminate both
        # before waiting so neither blocks the other's shutdown.
        if 'server_proc' in locals() and 'workers_proc' in locals():
            server_proc.terminate()
            workers_proc.terminate()
            server_proc.wait()
            workers_proc.wait()
2254+
2255+
2256+
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
                         indirect=True)
def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
                                                     disaggregated_example_root,
                                                     llm_venv,
                                                     deepseek_v3_model_root):
    """
    Test that the disaggregated server handles request cancellations gracefully.

    Bursts of large-context requests are sent and aborted during prefill to
    stress resource cleanup; a normal request must still succeed afterwards.
    """
    # Expose the model checkpoint at the path the YAML config expects.
    model_link = f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16"
    if not os.path.islink(model_link):
        os.makedirs(os.path.dirname(model_link), exist_ok=True)
        os.symlink(deepseek_v3_model_root,
                   model_link,
                   target_is_directory=True)

    run_disaggregated_cancel_test(disaggregated_example_root,
                                  "cancel_stress_test",
                                  env=llm_venv._new_env,
                                  cwd=llm_venv.get_working_directory(),
                                  num_bursts=5,
                                  requests_per_burst=32)
2283+

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ l0_dgx_h100:
4343
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
4444
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
4545
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
46+
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
4647
# ------------- AutoDeploy tests ---------------
4748
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
4849
# llmapi

0 commit comments

Comments
 (0)