Skip to content

Commit b1fa2bb

Browse files
committed
Create constants
1 parent 19575f1 commit b1fa2bb

File tree

8 files changed

+108
-112
lines changed

8 files changed

+108
-112
lines changed

ai_infra_bench/check.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from typing import List
66

7+
from ai_infra_bench.constants import DEFAULT_BENCH_SERVING_PATH, SGLANG_KEYS
78
from ai_infra_bench.utils import is_ci
89

910
try:
@@ -15,8 +16,6 @@
1516

1617
logger = logging.getLogger(__name__)
1718

18-
DEFAULT_BENCH_SERVING_PATH = "/tmp/ai_infra_bench/bench_serving.py"
19-
2019

2120
def ensure_bench_serving_available() -> None:
2221
"""
@@ -127,45 +126,6 @@ def install_bench_serving_dependencies() -> None:
127126
)
128127

129128

130-
SGLANG_KEYS = [
131-
"backend",
132-
"dataset_name",
133-
"request_rate",
134-
"max_concurrency",
135-
"sharegpt_output_len",
136-
"random_input_len",
137-
"random_output_len",
138-
"random_range_ratio",
139-
"duration",
140-
"completed",
141-
"total_input_tokens",
142-
"total_output_tokens",
143-
"total_output_tokens_retokenized",
144-
"request_throughput",
145-
"input_throughput",
146-
"output_throughput",
147-
"mean_e2e_latency_ms",
148-
"median_e2e_latency_ms",
149-
"std_e2e_latency_ms",
150-
"p99_e2e_latency_ms",
151-
"mean_ttft_ms",
152-
"median_ttft_ms",
153-
"std_ttft_ms",
154-
"p99_ttft_ms",
155-
"mean_tpot_ms",
156-
"median_tpot_ms",
157-
"std_tpot_ms",
158-
"p99_tpot_ms",
159-
"mean_itl_ms",
160-
"median_itl_ms",
161-
"std_itl_ms",
162-
"p95_itl_ms",
163-
"p99_itl_ms",
164-
"concurrency",
165-
"accept_length",
166-
]
167-
168-
169129
def check_dir(output_dir: str, full_data_json_path):
170130
"""
171131
Checks if the specified output directory exists. If it does, it prompts the user

ai_infra_bench/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
check_str_list_str,
1111
check_values_in_features_metrics,
1212
)
13+
from ai_infra_bench.constants import FULL_DATA_JSON_PATH
1314
from ai_infra_bench.modes.cmp import cmp_export_table
1415
from ai_infra_bench.modes.gen import gen_export_csv, gen_export_table, gen_plot, gen_run
1516
from ai_infra_bench.modes.slo import slo_run
1617
from ai_infra_bench.utils import (
17-
FULL_DATA_JSON_PATH,
1818
ServerAccessInfo,
1919
add_request_rate,
2020
cmp_preprocess_client_cmds,
@@ -171,6 +171,7 @@ def client_gen(
171171
input_features=input_features,
172172
output_metrics=output_metrics,
173173
output_dir=output_dir,
174+
server_label=server_labels[0],
174175
)
175176

176177
if not disable_csv:

ai_infra_bench/constants.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
DEFAULT_BENCH_SERVING_PATH = "/tmp/ai_infra_bench/bench_serving.py"
2+
3+
FULL_DATA_JSON_PATH = "full_data_json" # used to store all json files
4+
TABLE_NAME = "table.md"
5+
CSV_NAME = "data.csv"
6+
WARMUP_FILE = ".warmup.json"
7+
COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b"]
8+
GRAPH_PER_ROW = 3
9+
10+
SGLANG_KEYS = [
11+
"backend",
12+
"dataset_name",
13+
"request_rate",
14+
"max_concurrency",
15+
"sharegpt_output_len",
16+
"random_input_len",
17+
"random_output_len",
18+
"random_range_ratio",
19+
"duration",
20+
"completed",
21+
"total_input_tokens",
22+
"total_output_tokens",
23+
"total_output_tokens_retokenized",
24+
"request_throughput",
25+
"input_throughput",
26+
"output_throughput",
27+
"mean_e2e_latency_ms",
28+
"median_e2e_latency_ms",
29+
"std_e2e_latency_ms",
30+
"p99_e2e_latency_ms",
31+
"mean_ttft_ms",
32+
"median_ttft_ms",
33+
"std_ttft_ms",
34+
"p99_ttft_ms",
35+
"mean_tpot_ms",
36+
"median_tpot_ms",
37+
"std_tpot_ms",
38+
"p99_tpot_ms",
39+
"mean_itl_ms",
40+
"median_itl_ms",
41+
"std_itl_ms",
42+
"p95_itl_ms",
43+
"p99_itl_ms",
44+
"concurrency",
45+
"accept_length",
46+
]
47+
48+
49+
demo_output = {
50+
"backend": "sglang-oai",
51+
"dataset_name": "random",
52+
"request_rate": 10.0,
53+
"max_concurrency": 10,
54+
"sharegpt_output_len": None,
55+
"random_input_len": 1200,
56+
"random_output_len": 800,
57+
"random_range_ratio": 1.0,
58+
"duration": 45.11868940386921,
59+
"completed": 100,
60+
"total_input_tokens": 120000,
61+
"total_output_tokens": 80000,
62+
"total_output_tokens_retokenized": 79998,
63+
"request_throughput": 2.2163764356024127,
64+
"input_throughput": 2659.6517227228956,
65+
"output_throughput": 1773.1011484819303,
66+
"mean_e2e_latency_ms": 4482.026166650467,
67+
"median_e2e_latency_ms": 4487.435979535803,
68+
"std_e2e_latency_ms": 32.15524448450066,
69+
"p99_e2e_latency_ms": 4534.823208898306,
70+
"mean_ttft_ms": 38.534140698611736,
71+
"median_ttft_ms": 42.44273528456688,
72+
"std_ttft_ms": 10.558202315257851,
73+
"p99_ttft_ms": 61.15902605932206,
74+
"mean_tpot_ms": 5.561316678287678,
75+
"median_tpot_ms": 5.56157646876747,
76+
"std_tpot_ms": 0.04168330778296244,
77+
"p99_tpot_ms": 5.627061070545631,
78+
"mean_itl_ms": 5.561935330397016,
79+
"median_itl_ms": 5.495080258697271,
80+
"std_itl_ms": 1.1977701758121588,
81+
"p95_itl_ms": 6.047771545127034,
82+
"p99_itl_ms": 6.62423954345286,
83+
"concurrency": 9.933857179517508,
84+
"accept_length": None,
85+
}

ai_infra_bench/modes/cmp.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44
import plotly.graph_objects as go
55
from plotly.subplots import make_subplots
66

7-
from ai_infra_bench.utils import (
8-
TABLE_NAME,
9-
avg_std_strf,
10-
colors,
11-
enter_decorate,
12-
graph_per_row,
13-
)
7+
from ai_infra_bench.constants import COLORS, GRAPH_PER_ROW, TABLE_NAME
8+
from ai_infra_bench.utils import avg_std_strf, enter_decorate
149

1510

1611
@enter_decorate("PLOT TO HTML", filename="<input_feature>.html")
@@ -23,8 +18,8 @@ def cmp_plot(data, input_features, metrics, labels, output_dir):
2318

2419
# there are totally len(input_features) html files
2520
for input_feature in input_features:
26-
rows = (len(metrics) - 1) // graph_per_row + 1
27-
cols = graph_per_row
21+
rows = (len(metrics) - 1) // GRAPH_PER_ROW + 1
22+
cols = GRAPH_PER_ROW
2823
fig = make_subplots(rows=rows, cols=cols)
2924

3025
# there totally are len(metric) subplots
@@ -47,7 +42,7 @@ def cmp_plot(data, input_features, metrics, labels, output_dir):
4742
mode="lines+markers",
4843
marker=dict(size=8),
4944
line=dict(
50-
color=colors[server_idx % len(colors)],
45+
color=COLORS[server_idx % len(COLORS)],
5146
width=3,
5247
),
5348
hovertemplate=f"<br>{input_feature}: %{{x}}<br>{metric}: %{{y}}<br><extra></extra>",
@@ -60,7 +55,7 @@ def cmp_plot(data, input_features, metrics, labels, output_dir):
6055

6156
# one subplot is over
6257
cur_col += 1
63-
if cur_col == graph_per_row:
58+
if cur_col == GRAPH_PER_ROW:
6459
cur_col = 0
6560
cur_row += 1
6661

ai_infra_bench/modes/gen.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
from plotly.subplots import make_subplots
1010
from tqdm import tqdm
1111

12-
from ai_infra_bench.utils import (
12+
from ai_infra_bench.constants import (
13+
COLORS,
1314
CSV_NAME,
1415
FULL_DATA_JSON_PATH,
16+
GRAPH_PER_ROW,
1517
TABLE_NAME,
16-
avg_std_strf,
17-
colors,
18-
enter_decorate,
19-
graph_per_row,
20-
read_jsonl,
21-
run_cmd,
2218
)
19+
from ai_infra_bench.utils import avg_std_strf, enter_decorate, read_jsonl, run_cmd
2320

2421
logger = logging.getLogger(__name__)
2522

@@ -103,22 +100,22 @@ def gen_plot(
103100
):
104101
for feature in input_features:
105102
num_graphs = len(output_metrics)
106-
num_rows = math.ceil(num_graphs / graph_per_row)
103+
num_rows = math.ceil(num_graphs / GRAPH_PER_ROW)
107104

108-
fig = make_subplots(rows=num_rows, cols=graph_per_row)
105+
fig = make_subplots(rows=num_rows, cols=GRAPH_PER_ROW)
109106
x_values = [
110107
np.mean([item[feature] for item in client])
111108
for client in all_clients_results
112109
]
113110

114111
for idx, metric in enumerate(output_metrics):
115-
row, col = divmod(idx, graph_per_row)
112+
row, col = divmod(idx, GRAPH_PER_ROW)
116113

117114
y_values = [
118115
np.mean([item[metric] for item in client])
119116
for client in all_clients_results
120117
]
121-
color = colors[idx % len(colors)]
118+
color = COLORS[idx % len(COLORS)]
122119

123120
fig.add_trace(
124121
go.Scatter(
@@ -141,7 +138,7 @@ def gen_plot(
141138
title_text=f"{server_label} - {feature}" if server_label else feature,
142139
showlegend=True,
143140
height=300 * num_rows,
144-
width=400 * graph_per_row,
141+
width=400 * GRAPH_PER_ROW,
145142
margin=dict(t=50, b=30, l=30, r=30),
146143
)
147144

ai_infra_bench/modes/slo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import numpy as np
66

7+
from ai_infra_bench.constants import FULL_DATA_JSON_PATH
78
from ai_infra_bench.utils import (
8-
FULL_DATA_JSON_PATH,
99
add_request_rate,
1010
enter_decorate,
1111
read_jsonl,

ai_infra_bench/utils.py

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import psutil
1818
import requests
1919

20+
from ai_infra_bench.constants import WARMUP_FILE, demo_output
21+
2022

2123
@dataclass
2224
class ServerAccessInfo:
@@ -27,50 +29,6 @@ class ServerAccessInfo:
2729

2830
logger = logging.getLogger(__name__)
2931

30-
colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b"]
31-
graph_per_row = 3
32-
FULL_DATA_JSON_PATH = "full_data_json" # used to store all json files
33-
TABLE_NAME = "table.md"
34-
CSV_NAME = "data.csv"
35-
WARMUP_FILE = ".warmup.json"
36-
demo_output = {
37-
"backend": "sglang-oai",
38-
"dataset_name": "random",
39-
"request_rate": 10.0,
40-
"max_concurrency": 10,
41-
"sharegpt_output_len": None,
42-
"random_input_len": 1200,
43-
"random_output_len": 800,
44-
"random_range_ratio": 1.0,
45-
"duration": 45.11868940386921,
46-
"completed": 100,
47-
"total_input_tokens": 120000,
48-
"total_output_tokens": 80000,
49-
"total_output_tokens_retokenized": 79998,
50-
"request_throughput": 2.2163764356024127,
51-
"input_throughput": 2659.6517227228956,
52-
"output_throughput": 1773.1011484819303,
53-
"mean_e2e_latency_ms": 4482.026166650467,
54-
"median_e2e_latency_ms": 4487.435979535803,
55-
"std_e2e_latency_ms": 32.15524448450066,
56-
"p99_e2e_latency_ms": 4534.823208898306,
57-
"mean_ttft_ms": 38.534140698611736,
58-
"median_ttft_ms": 42.44273528456688,
59-
"std_ttft_ms": 10.558202315257851,
60-
"p99_ttft_ms": 61.15902605932206,
61-
"mean_tpot_ms": 5.561316678287678,
62-
"median_tpot_ms": 5.56157646876747,
63-
"std_tpot_ms": 0.04168330778296244,
64-
"p99_tpot_ms": 5.627061070545631,
65-
"mean_itl_ms": 5.561935330397016,
66-
"median_itl_ms": 5.495080258697271,
67-
"std_itl_ms": 1.1977701758121588,
68-
"p95_itl_ms": 6.047771545127034,
69-
"p99_itl_ms": 6.62423954345286,
70-
"concurrency": 9.933857179517508,
71-
"accept_length": None,
72-
}
73-
7432

7533
def cmp_preprocess_client_cmds(
7634
client_cmds: List[str], server_access_info: ServerAccessInfo

examples/client_gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# input args
66
base_url = os.environ["BASE_URL"]
7-
dataset_path = os.environ["SHAREGPT_DATASET"]
7+
dataset_path = os.environ["SHAREGPT_DATAPATH"]
88
input_features = [
99
"random_input_len",
1010
"random_output_len",
@@ -85,5 +85,5 @@
8585
server_label="qwen3_06b",
8686
n=3,
8787
only_last=True,
88-
output_dir="output",
88+
output_dir="client_gen_output",
8989
)

0 commit comments

Comments (0)