Skip to content

Commit ec0a5a9

Browse files
committed
Improve client usage
1 parent 84a4add commit ec0a5a9

File tree

3 files changed

+179
-26
lines changed

3 files changed

+179
-26
lines changed

examples/client_cmp.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
3+
from ai_infra_bench import client_cmp
4+
from ai_infra_bench.utils import ServerAccessInfo
5+
6+
# input args
7+
input_len = 1200
8+
output_len = 800
9+
dataset_path = os.environ["SHAREGPT_DATASET"]
10+
input_features = [
11+
"random_input_len",
12+
"random_output_len",
13+
"request_rate",
14+
"max_concurrency",
15+
]
16+
output_metrics = [
17+
"mean_ttft_ms",
18+
"p99_ttft_ms",
19+
"mean_tpot_ms",
20+
"p99_tpot_ms",
21+
"mean_itl_ms",
22+
"p99_itl_ms",
23+
"mean_e2e_latency_ms",
24+
"p99_e2e_latency_ms",
25+
"output_throughput",
26+
]
27+
28+
# construct client requests
29+
# don't set --base-url due to it will be contained in the server access infos
30+
client_template = """
31+
python -m sglang.bench_serving
32+
--backend sglang-oai
33+
--tokenizer Qwen/Qwen3-0.6B
34+
--model Qwen/Qwen3-0.6B
35+
--dataset-path {dataset_path}
36+
--dataset-name random
37+
--random-range-ratio 1
38+
--random-input-len {input_len}
39+
--random-output-len {output_len}
40+
--request-rate {request_rate}
41+
--max-concurrency {request_rate}
42+
--num-prompt {num_prompt}
43+
"""
44+
rate_list = [1, 2, 4, 8]
45+
client_cmds = [
46+
client_template.format(
47+
input_len=input_len,
48+
output_len=output_len,
49+
dataset_path=dataset_path,
50+
request_rate=rate,
51+
num_prompt=min(max(rate * 10, 80), 250), # clip to [80, 250]
52+
)
53+
for rate in rate_list
54+
]
55+
56+
# construct server access info
57+
server_access_infos = [
58+
ServerAccessInfo(
59+
base_url="http://localhost:8888", api_key="JustKeepMe", label="old"
60+
),
61+
ServerAccessInfo(
62+
base_url="http://localhost:8889", api_key="JustKeepMe", label="new"
63+
),
64+
]
65+
66+
67+
if __name__ == "__main__":
68+
client_cmp(
69+
server_access_infos=server_access_infos,
70+
client_cmds=client_cmds,
71+
input_features=input_features,
72+
output_metrics=output_metrics,
73+
n=3,
74+
only_last=True,
75+
output_dir="version_cmp_bench",
76+
)

examples/client_gen.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,34 @@
22

33
from ai_infra_bench import client_gen
44

5+
# input args
56
base_url = os.environ["BASE_URL"]
67
dataset_path = os.environ["SHAREGPT_DATASET"]
8+
input_features = [
9+
"random_input_len",
10+
"random_output_len",
11+
"request_rate",
12+
"max_concurrency",
13+
]
14+
output_metrics = [
15+
"mean_ttft_ms",
16+
"p99_ttft_ms",
17+
"mean_tpot_ms",
18+
"p99_tpot_ms",
19+
"mean_itl_ms",
20+
"p99_itl_ms",
21+
"mean_e2e_latency_ms",
22+
"p99_e2e_latency_ms",
23+
"output_throughput",
24+
]
725

26+
# construct client requests
827
client_template = """
9-
python -m sglang.bench_serving \
28+
python -m sglang.bench_serving
1029
--base-url {base_url}
1130
--backend sglang-oai
12-
--tokenizer deepseek-ai/DeepSeek-R1-0528
13-
--model deepseek-ai/DeepSeek-R1-0528
31+
--tokenizer Qwen/Qwen3-0.6B
32+
--model Qwen/Qwen3-0.6B
1433
--dataset-path {dataset_path}
1534
--dataset-name random
1635
--random-range-ratio 1
@@ -20,27 +39,27 @@
2039
--max-concurrency {request_rate}
2140
--num-prompt {num_prompt}
2241
"""
23-
rate_lists = [1, 2, 4, 8, 16, 24, 32, 40]
42+
rate_lists = [1, 2, 4, 8]
2443
client_cmds = [
2544
*[
2645
client_template.format(
2746
base_url=base_url,
28-
input_len=2000,
29-
output_len=1500,
47+
input_len=1200,
48+
output_len=800,
3049
dataset_path=dataset_path,
3150
request_rate=rate,
32-
num_prompt=rate * 10,
51+
num_prompt=min(max(rate * 10, 80), 250), # clip to [80, 250]
3352
)
3453
for rate in rate_lists
3554
],
3655
*[
3756
client_template.format(
3857
base_url=base_url,
39-
input_len=900,
58+
input_len=800,
4059
output_len=1200,
4160
dataset_path=dataset_path,
4261
request_rate=rate,
43-
num_prompt=rate * 10,
62+
num_prompt=min(max(rate * 10, 80), 250), # clip to [80, 250]
4463
)
4564
for rate in rate_lists
4665
],
@@ -51,33 +70,20 @@
5170
output_len=1500,
5271
dataset_path=dataset_path,
5372
request_rate=rate,
54-
num_prompt=rate * 10,
73+
num_prompt=min(max(rate * 10, 80), 250), # clip to [80, 250]
5574
)
5675
for rate in rate_lists
5776
],
5877
]
5978

60-
input_features = [
61-
"random_input_len",
62-
"random_output_len",
63-
"request_rate",
64-
"max_concurrency",
65-
]
66-
67-
output_metrics = [
68-
"p99_ttft_ms",
69-
"p99_tpot_ms",
70-
"p99_itl_ms",
71-
"output_throughput",
72-
"p99_e2e_latency_ms",
73-
"completed",
74-
]
7579

7680
if __name__ == "__main__":
7781
client_gen(
7882
client_cmds=client_cmds,
7983
input_features=input_features,
8084
output_metrics=output_metrics,
81-
server_labels="deepseek_r1",
85+
server_labels="qwen3_06b",
86+
n=3,
87+
only_last=True,
8288
output_dir="output",
8389
)

examples/client_slo.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
from typing import Dict
3+
4+
from ai_infra_bench import client_slo
5+
6+
# input args
7+
input_len = 1200
8+
output_len = 800
9+
base_url = os.environ["BASE_URL"]
10+
dataset_path = os.environ["SHAREGPT_DATASET"]
11+
input_features = [
12+
"random_input_len",
13+
"random_output_len",
14+
"request_rate",
15+
"max_concurrency",
16+
]
17+
output_metrics = [
18+
"mean_ttft_ms",
19+
"p99_ttft_ms",
20+
"mean_tpot_ms",
21+
"p99_tpot_ms",
22+
"mean_itl_ms",
23+
"p99_itl_ms",
24+
"mean_e2e_latency_ms",
25+
"p99_e2e_latency_ms",
26+
"output_throughput",
27+
]
28+
29+
# construct client requests
30+
client_template = """
31+
python -m sglang.bench_serving
32+
--base-url {base_url}
33+
--backend sglang-oai
34+
--tokenizer Qwen/Qwen3-0.6B
35+
--model Qwen/Qwen3-0.6B
36+
--dataset-path {dataset_path}
37+
--dataset-name random
38+
--random-range-ratio 1
39+
--random-input-len {input_len}
40+
--random-output-len {output_len}
41+
"""
42+
client_cmds = client_template.format(
43+
base_url=base_url,
44+
dataset_path=dataset_path,
45+
input_len=input_len,
46+
output_len=output_len,
47+
)
48+
49+
50+
def check_slo(item: Dict) -> bool:
51+
return (
52+
item["p99_ttft_ms"] < 3000
53+
and item["p99_tpot_ms"] < 100
54+
and item["p99_itl_ms"] < 100
55+
)
56+
57+
58+
request_rates = [(20, 70)]
59+
60+
61+
if __name__ == "__main__":
62+
client_slo(
63+
client_cmds=client_cmds,
64+
input_features=input_features,
65+
output_metrics=output_metrics,
66+
check_slo=check_slo,
67+
request_rates=request_rates,
68+
n=3,
69+
only_last=True,
70+
output_dir="client_slo_output",
71+
)

0 commit comments

Comments
 (0)