Skip to content

Commit bab4fd2

Browse files
committed
[None][test] Parameterize multi-turn router test for completion and chat APIs
Add an `api_type` parametrization ("completion", "chat") to test_kv_cache_aware_router_multi_turn_conversation so it exercises both the CompletionRequest (prompt with token IDs) and ChatCompletionRequest (prompt_token_ids) code paths through the KvCacheAwareRouter.

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 514f06e commit bab4fd2

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

tests/unittest/disaggregated/test_router.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -320,7 +320,8 @@ async def test_kv_cache_aware_router(servers):
320320

321321

322322
@pytest.mark.asyncio
323-
async def test_kv_cache_aware_router_multi_turn_conversation():
323+
@pytest.mark.parametrize("api_type", ["completion", "chat"])
324+
async def test_kv_cache_aware_router_multi_turn_conversation(api_type):
324325
"""Test that consecutive turns of a multi-turn conversation route to the
325326
same server due to KV cache prefix hits.
326327
@@ -360,6 +361,22 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
360361
tokens.append(0)
361362
return tokens
362363

364+
def make_request(token_ids: list[int]):
365+
"""Create a CompletionRequest or ChatCompletionRequest with pre-tokenized IDs."""
366+
if api_type == "completion":
367+
return CompletionRequest(model="TinyLlama",
368+
prompt=[token_ids])
369+
else:
370+
# Use prompt_token_ids to skip tokenizer (no real model needed)
371+
return ChatCompletionRequest(
372+
model="TinyLlama",
373+
messages=[{
374+
"role": "user",
375+
"content": "dummy"
376+
}],
377+
prompt_token_ids=token_ids,
378+
)
379+
363380
# -- dataset-inspired hash_ids per turn (new blocks only) -------------
364381
# Session A (the conversation under test)
365382
sess_a_turn0_hids = list(range(10)) # 10 blocks
@@ -374,7 +391,6 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
374391
sess_a_turn0_tokens = build_tokens(sess_a_turn0_hids)
375392

376393
# Turn 1 accumulated: turn 0 tokens + simulated assistant reply + new user tokens
377-
assistant_reply_tokens = [9999] * (tokens_per_block * 2) # 2 blocks of reply
378394
sess_a_turn1_tokens = build_tokens(
379395
sess_a_turn0_hids + [9990, 9991] + sess_a_turn1_hids
380396
)
@@ -391,11 +407,11 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
391407
# -- Round 1: initial routing (empty caches) --------------------------
392408
# Route both sessions concurrently so load-balancing spreads them to
393409
# different servers (with equal KV cache misses, ties are broken by load).
394-
req_a0 = CompletionRequest(model="TinyLlama", prompt=[sess_a_turn0_tokens])
410+
req_a0 = make_request(sess_a_turn0_tokens)
395411
server_a, info_a0 = await router.get_next_server(req_a0)
396412
# Do NOT finish req_a0 yet — keep its load active so session B avoids server_a
397413

398-
req_b0 = CompletionRequest(model="TinyLlama", prompt=[sess_b_tokens])
414+
req_b0 = make_request(sess_b_tokens)
399415
server_b, info_b0 = await router.get_next_server(req_b0)
400416

401417
# Now finish both and populate caches
@@ -413,7 +429,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
413429
assert blocks_a.isdisjoint(blocks_b), "Different sessions must not share block hashes"
414430

415431
# -- Round 2: turn 1 of session A (prefix extends turn 0) ------------
416-
req_a1 = CompletionRequest(model="TinyLlama", prompt=[sess_a_turn1_tokens])
432+
req_a1 = make_request(sess_a_turn1_tokens)
417433
server_a1, info_a1 = await router.get_next_server(req_a1)
418434
await router.finish_request(req_a1)
419435

@@ -435,7 +451,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
435451
router._server_state[server_a].add_blocks(info_a1["block_hashes"][0])
436452

437453
# -- Round 3: turn 2 of session A (prefix extends turn 1) ------------
438-
req_a2 = CompletionRequest(model="TinyLlama", prompt=[sess_a_turn2_tokens])
454+
req_a2 = make_request(sess_a_turn2_tokens)
439455
server_a2, info_a2 = await router.get_next_server(req_a2)
440456
await router.finish_request(req_a2)
441457

@@ -455,7 +471,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
455471
)
456472

457473
# -- Verify session B still routes to its own server ------------------
458-
req_b1 = CompletionRequest(model="TinyLlama", prompt=[sess_b_tokens])
474+
req_b1 = make_request(sess_b_tokens)
459475
server_b1, info_b1 = await router.get_next_server(req_b1)
460476
await router.finish_request(req_b1)
461477

0 commit comments

Comments (0)