@@ -320,7 +320,8 @@ async def test_kv_cache_aware_router(servers):
320320
321321
322322@pytest .mark .asyncio
323- async def test_kv_cache_aware_router_multi_turn_conversation ():
323+ @pytest .mark .parametrize ("api_type" , ["completion" , "chat" ])
324+ async def test_kv_cache_aware_router_multi_turn_conversation (api_type ):
324325 """Test that consecutive turns of a multi-turn conversation route to the
325326 same server due to KV cache prefix hits.
326327
@@ -360,6 +361,22 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
360361 tokens .append (0 )
361362 return tokens
362363
def make_request(token_ids: list[int]):
    """Build a request of the parametrized API type carrying pre-tokenized IDs.

    Returns a ``CompletionRequest`` when ``api_type`` is ``"completion"``,
    otherwise a ``ChatCompletionRequest`` whose ``prompt_token_ids`` bypasses
    the tokenizer (so no real model/tokenizer is needed for the test).
    """
    if api_type == "completion":
        return CompletionRequest(model="TinyLlama", prompt=[token_ids])
    # Chat path: the message content is a dummy placeholder; routing is
    # driven entirely by the pre-tokenized IDs.
    return ChatCompletionRequest(
        model="TinyLlama",
        messages=[{"role": "user", "content": "dummy"}],
        prompt_token_ids=token_ids,
    )
379+
363380 # -- dataset-inspired hash_ids per turn (new blocks only) -------------
364381 # Session A (the conversation under test)
365382 sess_a_turn0_hids = list (range (10 )) # 10 blocks
@@ -374,7 +391,6 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
374391 sess_a_turn0_tokens = build_tokens (sess_a_turn0_hids )
375392
376393 # Turn 1 accumulated: turn 0 tokens + simulated assistant reply + new user tokens
377- assistant_reply_tokens = [9999 ] * (tokens_per_block * 2 ) # 2 blocks of reply
378394 sess_a_turn1_tokens = build_tokens (
379395 sess_a_turn0_hids + [9990 , 9991 ] + sess_a_turn1_hids
380396 )
@@ -391,11 +407,11 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
391407 # -- Round 1: initial routing (empty caches) --------------------------
392408 # Route both sessions concurrently so load-balancing spreads them to
393409 # different servers (with equal KV cache misses, ties are broken by load).
394- req_a0 = CompletionRequest ( model = "TinyLlama" , prompt = [ sess_a_turn0_tokens ] )
410+ req_a0 = make_request ( sess_a_turn0_tokens )
395411 server_a , info_a0 = await router .get_next_server (req_a0 )
396412 # Do NOT finish req_a0 yet — keep its load active so session B avoids server_a
397413
398- req_b0 = CompletionRequest ( model = "TinyLlama" , prompt = [ sess_b_tokens ] )
414+ req_b0 = make_request ( sess_b_tokens )
399415 server_b , info_b0 = await router .get_next_server (req_b0 )
400416
401417 # Now finish both and populate caches
@@ -413,7 +429,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
413429 assert blocks_a .isdisjoint (blocks_b ), "Different sessions must not share block hashes"
414430
415431 # -- Round 2: turn 1 of session A (prefix extends turn 0) ------------
416- req_a1 = CompletionRequest ( model = "TinyLlama" , prompt = [ sess_a_turn1_tokens ] )
432+ req_a1 = make_request ( sess_a_turn1_tokens )
417433 server_a1 , info_a1 = await router .get_next_server (req_a1 )
418434 await router .finish_request (req_a1 )
419435
@@ -435,7 +451,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
435451 router ._server_state [server_a ].add_blocks (info_a1 ["block_hashes" ][0 ])
436452
437453 # -- Round 3: turn 2 of session A (prefix extends turn 1) ------------
438- req_a2 = CompletionRequest ( model = "TinyLlama" , prompt = [ sess_a_turn2_tokens ] )
454+ req_a2 = make_request ( sess_a_turn2_tokens )
439455 server_a2 , info_a2 = await router .get_next_server (req_a2 )
440456 await router .finish_request (req_a2 )
441457
@@ -455,7 +471,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
455471 )
456472
457473 # -- Verify session B still routes to its own server ------------------
458- req_b1 = CompletionRequest ( model = "TinyLlama" , prompt = [ sess_b_tokens ] )
474+ req_b1 = make_request ( sess_b_tokens )
459475 server_b1 , info_b1 = await router .get_next_server (req_b1 )
460476 await router .finish_request (req_b1 )
461477
0 commit comments