@@ -364,8 +364,7 @@ def build_tokens(hash_ids: list[int]) -> list[int]:
364364 def make_request (token_ids : list [int ]):
365365 """Create a CompletionRequest or ChatCompletionRequest with pre-tokenized IDs."""
366366 if api_type == "completion" :
367- return CompletionRequest (model = "TinyLlama" ,
368- prompt = [token_ids ])
367+ return CompletionRequest (model = "TinyLlama" , prompt = [token_ids ])
369368 else :
370369 # Use prompt_token_ids to skip tokenizer (no real model needed)
371370 return ChatCompletionRequest (
@@ -379,7 +378,7 @@ def make_request(token_ids: list[int]):
379378
380379 # -- dataset-inspired hash_ids per turn (new blocks only) -------------
381380 # Session A (the conversation under test)
382- sess_a_turn0_hids = list (range (10 )) # 10 blocks
381+ sess_a_turn0_hids = list (range (10 )) # 10 blocks
383382 sess_a_turn1_hids = list (range (100 , 103 )) # 3 new blocks
384383 sess_a_turn2_hids = list (range (200 , 202 )) # 2 new blocks
385384
@@ -391,16 +390,14 @@ def make_request(token_ids: list[int]):
391390 sess_a_turn0_tokens = build_tokens (sess_a_turn0_hids )
392391
393392 # Turn 1 accumulated: turn 0 tokens + simulated assistant reply + new user tokens
394- sess_a_turn1_tokens = build_tokens (
395- sess_a_turn0_hids + [9990 , 9991 ] + sess_a_turn1_hids
396- )
393+ sess_a_turn1_tokens = build_tokens (sess_a_turn0_hids + [9990 , 9991 ] +
394+ sess_a_turn1_hids )
397395 # (hash_ids 9990/9991 stand in for the assistant-reply blocks)
398396
399397 # Turn 2 accumulated: extends turn 1 further
400- sess_a_turn2_tokens = build_tokens (
401- sess_a_turn0_hids + [9990 , 9991 ] + sess_a_turn1_hids + [9992 , 9993 ]
402- + sess_a_turn2_hids
403- )
398+ sess_a_turn2_tokens = build_tokens (sess_a_turn0_hids + [9990 , 9991 ] +
399+ sess_a_turn1_hids + [9992 , 9993 ] +
400+ sess_a_turn2_hids )
404401
405402 sess_b_tokens = build_tokens (sess_b_turn0_hids )
406403
@@ -426,7 +423,8 @@ def make_request(token_ids: list[int]):
426423 # Verify block hashes are disjoint between sessions
427424 blocks_a = set (info_a0 ["block_hashes" ][0 ])
428425 blocks_b = set (info_b0 ["block_hashes" ][0 ])
429- assert blocks_a .isdisjoint (blocks_b ), "Different sessions must not share block hashes"
426+ assert blocks_a .isdisjoint (
427+ blocks_b ), "Different sessions must not share block hashes"
430428
431429 # -- Round 2: turn 1 of session A (prefix extends turn 0) ------------
432430 req_a1 = make_request (sess_a_turn1_tokens )
@@ -436,16 +434,14 @@ def make_request(token_ids: list[int]):
436434 assert server_a1 == server_a , (
437435 f"Turn 1 must route to the same server as turn 0 ({ server_a } ) "
438436 f"due to KV cache prefix hit, but got { server_a1 } . "
439- f"Matches: { info_a1 ['matches' ]} "
440- )
437+ f"Matches: { info_a1 ['matches' ]} " )
441438
442439 # The match count on server_a must equal the prefix overlap
443440 server_a_idx = list (router ._server_state .keys ()).index (server_a )
444441 expected_prefix_match = len (sess_a_turn0_hids ) * tokens_per_block
445442 assert info_a1 ["matches" ][server_a_idx ] == expected_prefix_match , (
446443 f"Expected { expected_prefix_match } matched tokens on server_a, "
447- f"got { info_a1 ['matches' ][server_a_idx ]} "
448- )
444+ f"got { info_a1 ['matches' ][server_a_idx ]} " )
449445
450446 # Update server_a cache with new blocks from turn 1
451447 router ._server_state [server_a ].add_blocks (info_a1 ["block_hashes" ][0 ])
@@ -458,17 +454,16 @@ def make_request(token_ids: list[int]):
458454 assert server_a2 == server_a , (
459455 f"Turn 2 must route to the same server as turns 0-1 ({ server_a } ) "
460456 f"due to KV cache prefix hit, but got { server_a2 } . "
461- f"Matches: { info_a2 ['matches' ]} "
462- )
457+ f"Matches: { info_a2 ['matches' ]} " )
463458
464459 # Turn 2 should match all of turn 0 + turn 1 prefix blocks
465460 expected_full_match = (
466- len (sess_a_turn0_hids ) + 2 + len (sess_a_turn1_hids ) # turn0 + reply + turn1
461+ len (sess_a_turn0_hids ) + 2 +
462+ len (sess_a_turn1_hids ) # turn0 + reply + turn1
467463 ) * tokens_per_block
468464 assert info_a2 ["matches" ][server_a_idx ] == expected_full_match , (
469465 f"Expected { expected_full_match } matched tokens on turn 2, "
470- f"got { info_a2 ['matches' ][server_a_idx ]} "
471- )
466+ f"got { info_a2 ['matches' ][server_a_idx ]} " )
472467
473468 # -- Verify session B still routes to its own server ------------------
474469 req_b1 = make_request (sess_b_tokens )
@@ -477,8 +472,7 @@ def make_request(token_ids: list[int]):
477472
478473 assert server_b1 == server_b , (
479474 f"Session B should route to its original server ({ server_b } ), "
480- f"but got { server_b1 } "
481- )
475+ f"but got { server_b1 } " )
482476
483477
484478def test_create_router (servers ):
0 commit comments