Skip to content
102 changes: 102 additions & 0 deletions examples/disaggregated/slurm/benchmark/run_benchmark_aiperf.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/bin/bash

# aiperf-based benchmark script for disaggregated serving.
# Args: model_name dataset_file multi_round num_gen_servers concurrency_list
#       streaming log_path hostname port ucx_warmup_requests

set -e
set -u
trap 'echo "Error occurred at line $LINENO" >&2; exit 1' ERR

# Validate the argument count up front; with `set -u` a missing positional
# would otherwise abort with a far less helpful "unbound variable" error.
if [ "$#" -lt 10 ]; then
  echo "Error: Missing required arguments, got $# arguments, args: $*" >&2
  echo "Usage: $0 model_name dataset_file multi_round num_gen_servers concurrency_list streaming log_path hostname port ucx_warmup_requests" >&2
  exit 1
fi

model_name=$1              # model identifier/path served by the endpoint
dataset_file=$2            # mooncake_trace dataset file consumed by aiperf
multi_round=$3             # rounds per concurrency level (scales request count)
num_gen_servers=$4         # number of generation servers (informational)
concurrency_list=$5        # comma-separated concurrency levels to sweep
streaming=$6               # streaming flag passed by the submitter
log_path=$7                # directory for logs and benchmark artifacts
hostname=$8                # disaggregated server hostname
port=$9                    # disaggregated server port
ucx_warmup_requests=${10}  # number of small warmup requests (0 disables warmup)

# Only SLURM rank 0 acts as the load generator; other ranks exit cleanly.
# Default to "0" so the script also runs outside of a SLURM allocation
# (the bare reference would trip `set -u` when SLURM_PROCID is unset).
if [[ "${SLURM_PROCID:-0}" != "0" ]]; then
  echo "Process id is ${SLURM_PROCID} for loadgen, exiting"
  exit 0
fi

# Always install/upgrade aiperf to ensure we have the version with trust_remote_code fix
# (container may have an older version with parallel_decode.py that lacks trust_remote_code)
echo "Installing aiperf..."
pip install --force-reinstall --no-deps 'aiperf @ git+https://github.com/ai-dynamo/aiperf.git@ac3d91652e5e024bfb4ac38d48603423aad666bc'

# Issue a handful of tiny requests first so UCX connections between the
# disaggregated servers are established before the measured run starts.
if [ "${ucx_warmup_requests}" -gt 0 ]; then
  echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
  python -m tensorrt_llm.serve.scripts.benchmark_serving \
    --model "${model_name}" \
    --dataset-name random \
    --random-ids \
    --random-input-len 100 \
    --random-output-len 10 \
    --num-prompts "${ucx_warmup_requests}" \
    --host "${hostname}" \
    --port "${port}" \
    --ignore-eos \
    --trust-remote-code \
    --non-streaming
  echo "UCX warmup done"
fi

# Trust remote code globally for custom tokenizers in parallel workers
export HF_HUB_TRUST_REMOTE_CODE=1

echo "Hostname: ${hostname}, Port: ${port}"
echo "Starting aiperf benchmark..."

# Convert the comma-separated list into words; the loop below relies on
# intentional word splitting of the unquoted expansion.
concurrency_list=$(echo "${concurrency_list}" | tr ',' ' ')
for concurrency in ${concurrency_list}; do
  # Arithmetic expansion normalizes the value and strips stray whitespace.
  concurrency=$((concurrency))
  request_count=$((concurrency * multi_round))
  # benchmark_duration: allow up to 20min (1200s) per round
  benchmark_duration=$((multi_round * 1200))
  echo "Benchmarking with concurrency ${concurrency} ... ${request_count} requests, duration ${benchmark_duration}s"
  mkdir -p "${log_path}/concurrency_${concurrency}"

  # NOTE(review): '--streaming' is always passed here; the script's
  # 'streaming' argument is never consulted — confirm this is intentional.
  aiperf profile \
    -m "${model_name}" \
    --tokenizer "${model_name}" \
    --tokenizer-trust-remote-code \
    --url "http://${hostname}:${port}" \
    --streaming \
    --ui simple \
    --input-file "${dataset_file}" \
    --artifact-dir "${log_path}/concurrency_${concurrency}" \
    --concurrency "${concurrency}" \
    --concurrency-ramp-duration 60 \
    --custom-dataset-type mooncake_trace \
    --benchmark-duration "${benchmark_duration}" \
    --benchmark-grace-period 60 \
    --workers-max 200 \
    --request-timeout-seconds 1200 \
    --profile-export-level records \
    --extra-inputs ignore_eos:true \
    --request-count "${request_count}" \
    --record-processors 8

  echo "Benchmark with concurrency ${concurrency} done"
done

# Fetch perf metrics from disagg server; failure is non-fatal (best effort).
echo "Fetching perf metrics from http://${hostname}:${port}/perf_metrics ..."
# Write only the HTTP response body to the JSON file. The previous `2>&1`
# merged curl's stderr into the artifact, which could corrupt the JSON.
curl -s "http://${hostname}:${port}/perf_metrics" -o "${log_path}/perf_metrics.json" || true
if [ -s "${log_path}/perf_metrics.json" ]; then
  echo "Perf metrics saved to ${log_path}/perf_metrics.json"
else
  echo "Warning: perf_metrics response was empty or endpoint not available"
fi
22 changes: 19 additions & 3 deletions examples/disaggregated/slurm/benchmark/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def assign_servers(
return allocations


def convert_allocations_to_server_config(allocations, server_port=8333):
def convert_allocations_to_server_config(allocations,
server_port=8333,
router_config=None):
generation_servers = {}
context_servers = {}
server_hostname = None
Expand All @@ -127,6 +129,8 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
f"{list(instance['nodes'].keys())[0]}:{instance['port']}")

server_config_entry = {'num_instances': num_servers, 'urls': urls}
if router_config:
server_config_entry['router'] = router_config.copy()

if server_type == "GEN":
generation_servers = server_config_entry
Expand Down Expand Up @@ -481,7 +485,12 @@ def submit_job(config, log_dir, dry_run):
json.dump(allocations, f, indent=2)

# Generate disagg server config
server_config = convert_allocations_to_server_config(allocations)
router_config = config.get('router_config', None)
server_config = convert_allocations_to_server_config(
allocations, router_config=router_config)
# Merge server_config_extra into disagg server config
if 'server_config_extra' in config:
server_config.update(config['server_config_extra'])
with open(os.path.join(log_dir, "server_config_base.yaml"), "w") as f:
yaml.dump(server_config, f)
disagg_server_hostname = server_config['hostname']
Expand Down Expand Up @@ -608,7 +617,14 @@ def submit_job(config, log_dir, dry_run):
benchmark_prefix = client_slurm_prefix + [
f"--export \"{convert_envs_to_str(env_var)}\""
]
if benchmark_config['use_nv_sa_benchmark']:
if benchmark_config.get('use_aiperf', False):
benchmark_cmd = [
f"bash {os.path.join(script_dir, 'run_benchmark_aiperf.sh')}",
f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
elif benchmark_config['use_nv_sa_benchmark']:
if benchmark_config['mode'] == "gen_only":
print(
f"[ERROR] SA benchmark client script is not supported for gen_only mode"
Expand Down
42 changes: 36 additions & 6 deletions tensorrt_llm/serve/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import heapq
import os
from abc import ABC, abstractmethod
from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union

Expand Down Expand Up @@ -626,9 +627,38 @@ def __init__(self,
self._tokenizers = {}
# TODO: use max_num_tokens? per server?
self._max_batch_size = max_batch_size
env_tokens_per_block = os.environ.get(
"TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK")
if env_tokens_per_block is not None:
tokens_per_block = int(env_tokens_per_block)
self._tokens_per_block = tokens_per_block
logger.info(
f"KvCacheAwareRouter: tokens_per_block={self._tokens_per_block}")

def _get_tokenizer(self, model: str):
if model not in self._tokenizers:
self._tokenizers[model] = AutoTokenizer.from_pretrained(model)
return self._tokenizers[model]

def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
# Handle ChatCompletionRequest (has messages, not prompt)
if isinstance(request, ChatCompletionRequest):
if request.prompt_token_ids is not None:
return [request.prompt_token_ids]
tokenizer = self._get_tokenizer(request.model)
token_ids = tokenizer.apply_chat_template(
[
msg if isinstance(msg, dict) else dict(msg)
for msg in request.messages
],
add_generation_prompt=request.add_generation_prompt,
tokenize=True,
)
# Set prompt_token_ids so the worker server skips re-tokenization
request.prompt_token_ids = token_ids
return [token_ids]

# Handle CompletionRequest (has prompt)
prompts = request.prompt
if isinstance(prompts, list) and isinstance(prompts[0], list):
return prompts
Expand All @@ -639,12 +669,12 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
else:
assert isinstance(prompts, list) and isinstance(prompts[0], str)

# TODO: send tokenize-only request instead of tokenizing locally
if request.model not in self._tokenizers:
self._tokenizers[request.model] = AutoTokenizer.from_pretrained(
request.model)
tokenizer = self._tokenizers[request.model]
return [tokenizer(prompt)["input_ids"] for prompt in prompts]
tokenizer = self._get_tokenizer(request.model)
token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts]
# Replace string prompts with token IDs so the worker server
# skips re-tokenization
request.prompt = token_lists if len(token_lists) > 1 else token_lists[0]
return token_lists

async def get_next_server(
self,
Expand Down
156 changes: 156 additions & 0 deletions tests/unittest/disaggregated/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,162 @@ async def test_kv_cache_aware_router(servers):
assert servers == ("server3", "server1", "server2")


@pytest.mark.asyncio
@pytest.mark.parametrize("api_type", ["completion", "chat"])
async def test_kv_cache_aware_router_multi_turn_conversation(api_type):
    """Test that consecutive turns of a multi-turn conversation route to the
    same server due to KV cache prefix hits.

    Simulates two concurrent sessions inspired by
    agentic_data/dataset_sample2000.jsonl session sess-fca58a1f44cd:
      Turn 0: 68 hash_ids (system prompt + first user input)
      Turn 1: 9 hash_ids (second user input, accumulated with turn 0)
      Turn 2: 6 hash_ids (third user input, accumulated with turn 1)

    Scaled down to 10, 3, 2 blocks for test manageability. Each hash_id
    maps to a deterministic block of tokens (mirroring aiperf's
    HashIdRandomGenerator). The router should prefer the server that
    already caches the conversation's prefix.

    Parametrized over both OpenAI API request shapes (completion / chat)
    so routing behaves identically regardless of the request type.
    """
    server_list = ["server1", "server2", "server3"]
    tokens_per_block = 32

    router = KvCacheAwareRouter(
        server_role=None,
        servers=server_list,
        use_tokens=False,
        max_batch_size=64,
        tokens_per_block=tokens_per_block,
    )

    # -- helpers ----------------------------------------------------------
    def hash_id_to_block(hash_id: int) -> list[int]:
        """Deterministic token block per hash_id (mirrors aiperf corpus sampling)."""
        return [(hash_id * 7 + i) % 50000 for i in range(tokens_per_block)]

    def build_tokens(hash_ids: list[int]) -> list[int]:
        # Concatenate the deterministic block for every hash_id in order.
        tokens = []
        for hid in hash_ids:
            tokens.extend(hash_id_to_block(hid))
        # Append one extra token so the last full block is included in hashing.
        # (KvCacheManager excludes the very last token from block keys.)
        tokens.append(0)
        return tokens

    def make_request(token_ids: list[int]):
        """Create a CompletionRequest or ChatCompletionRequest with pre-tokenized IDs."""
        if api_type == "completion":
            return CompletionRequest(model="TinyLlama", prompt=[token_ids])
        else:
            # Use prompt_token_ids to skip tokenizer (no real model needed)
            return ChatCompletionRequest(
                model="TinyLlama",
                messages=[{
                    "role": "user",
                    "content": "dummy"
                }],
                prompt_token_ids=token_ids,
            )

    # -- dataset-inspired hash_ids per turn (new blocks only) -------------
    # Session A (the conversation under test)
    sess_a_turn0_hids = list(range(10))  # 10 blocks
    sess_a_turn1_hids = list(range(100, 103))  # 3 new blocks
    sess_a_turn2_hids = list(range(200, 202))  # 2 new blocks

    # Session B (competing traffic on a different server)
    sess_b_turn0_hids = list(range(500, 510))  # 10 completely different blocks

    # -- build accumulated token sequences --------------------------------
    # Turn 0: just the first turn's tokens
    sess_a_turn0_tokens = build_tokens(sess_a_turn0_hids)

    # Turn 1 accumulated: turn 0 tokens + simulated assistant reply + new user tokens
    sess_a_turn1_tokens = build_tokens(sess_a_turn0_hids + [9990, 9991] +
                                       sess_a_turn1_hids)
    # (hash_ids 9990/9991 stand in for the assistant-reply blocks)

    # Turn 2 accumulated: extends turn 1 further
    sess_a_turn2_tokens = build_tokens(sess_a_turn0_hids + [9990, 9991] +
                                       sess_a_turn1_hids + [9992, 9993] +
                                       sess_a_turn2_hids)

    sess_b_tokens = build_tokens(sess_b_turn0_hids)

    # -- Round 1: initial routing (empty caches) --------------------------
    # Route both sessions concurrently so load-balancing spreads them to
    # different servers (with equal KV cache misses, ties are broken by load).
    req_a0 = make_request(sess_a_turn0_tokens)
    server_a, info_a0 = await router.get_next_server(req_a0)
    # Do NOT finish req_a0 yet — keep its load active so session B avoids server_a

    req_b0 = make_request(sess_b_tokens)
    server_b, info_b0 = await router.get_next_server(req_b0)

    # Now finish both and populate caches
    await router.finish_request(req_a0)
    await router.finish_request(req_b0)
    router._server_state[server_a].add_blocks(info_a0["block_hashes"][0])
    router._server_state[server_b].add_blocks(info_b0["block_hashes"][0])

    # Sanity: two sessions should land on different servers
    assert server_a != server_b, "Disjoint sessions should land on different servers"

    # Verify block hashes are disjoint between sessions
    blocks_a = set(info_a0["block_hashes"][0])
    blocks_b = set(info_b0["block_hashes"][0])
    assert blocks_a.isdisjoint(
        blocks_b), "Different sessions must not share block hashes"

    # -- Round 2: turn 1 of session A (prefix extends turn 0) ------------
    req_a1 = make_request(sess_a_turn1_tokens)
    server_a1, info_a1 = await router.get_next_server(req_a1)
    await router.finish_request(req_a1)

    assert server_a1 == server_a, (
        f"Turn 1 must route to the same server as turn 0 ({server_a}) "
        f"due to KV cache prefix hit, but got {server_a1}. "
        f"Matches: {info_a1['matches']}")

    # The match count on server_a must equal the prefix overlap
    server_a_idx = list(router._server_state.keys()).index(server_a)
    expected_prefix_match = len(sess_a_turn0_hids) * tokens_per_block
    assert info_a1["matches"][server_a_idx] == expected_prefix_match, (
        f"Expected {expected_prefix_match} matched tokens on server_a, "
        f"got {info_a1['matches'][server_a_idx]}")

    # Update server_a cache with new blocks from turn 1
    router._server_state[server_a].add_blocks(info_a1["block_hashes"][0])

    # -- Round 3: turn 2 of session A (prefix extends turn 1) ------------
    req_a2 = make_request(sess_a_turn2_tokens)
    server_a2, info_a2 = await router.get_next_server(req_a2)
    await router.finish_request(req_a2)

    assert server_a2 == server_a, (
        f"Turn 2 must route to the same server as turns 0-1 ({server_a}) "
        f"due to KV cache prefix hit, but got {server_a2}. "
        f"Matches: {info_a2['matches']}")

    # Turn 2 should match all of turn 0 + turn 1 prefix blocks
    expected_full_match = (
        len(sess_a_turn0_hids) + 2 +
        len(sess_a_turn1_hids)  # turn0 + reply + turn1
    ) * tokens_per_block
    assert info_a2["matches"][server_a_idx] == expected_full_match, (
        f"Expected {expected_full_match} matched tokens on turn 2, "
        f"got {info_a2['matches'][server_a_idx]}")

    # -- Verify session B still routes to its own server ------------------
    req_b1 = make_request(sess_b_tokens)
    server_b1, info_b1 = await router.get_next_server(req_b1)
    await router.finish_request(req_b1)

    assert server_b1 == server_b, (
        f"Session B should route to its original server ({server_b}), "
        f"but got {server_b1}")


def test_create_router(servers):
default_router = create_router(None, servers)
assert isinstance(default_router, RoundRobinRouter)
Expand Down
Loading