Skip to content

Commit 903ca3f

Browse files
committed
Fix CI
1 parent b1fa2bb commit 903ca3f

File tree

8 files changed

+43
-55
lines changed

8 files changed

+43
-55
lines changed

ai_infra_bench/sgl/cmp_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
check_str_list_str,
1212
check_values_in_features_metrics,
1313
)
14+
from ai_infra_bench.constants import FULL_DATA_JSON_PATH
1415
from ai_infra_bench.modes.cmp import cmp_export_table
1516
from ai_infra_bench.modes.gen import gen_export_csv, gen_run
1617
from ai_infra_bench.utils import (
17-
FULL_DATA_JSON_PATH,
1818
kill_process_tree,
1919
maybe_create_labels,
2020
maybe_warmup,

ai_infra_bench/sgl/gen_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
check_server_labels,
1111
check_str_list_str,
1212
)
13+
from ai_infra_bench.constants import FULL_DATA_JSON_PATH
1314
from ai_infra_bench.modes.gen import gen_export_csv, gen_export_table, gen_plot, gen_run
1415
from ai_infra_bench.utils import (
15-
FULL_DATA_JSON_PATH,
1616
kill_process_tree,
1717
maybe_create_labels,
1818
maybe_warmup,

ai_infra_bench/sgl/slo_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
check_server_labels,
1515
check_str_list_str,
1616
)
17+
from ai_infra_bench.constants import FULL_DATA_JSON_PATH
1718
from ai_infra_bench.modes.slo import slo_run
1819
from ai_infra_bench.utils import (
19-
FULL_DATA_JSON_PATH,
2020
add_request_rate,
2121
kill_process_tree,
2222
maybe_create_labels,

examples/client_gen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import List
23

34
from ai_infra_bench import client_gen
45

@@ -39,8 +40,8 @@
3940
--max-concurrency {request_rate}
4041
--num-prompt {num_prompt}
4142
"""
42-
rate_lists = [1, 2, 4, 8]
43-
client_cmds = [
43+
rate_lists: List[int] = [1, 2, 4, 8]
44+
client_cmds: List[str] = [
4445
*[
4546
client_template.format(
4647
base_url=base_url,

test/modes/test_cmp.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
from utils import check_output_content, input_features, output_metrics
77

88
from ai_infra_bench.client import client_cmp
9+
from ai_infra_bench.constants import CSV_NAME, TABLE_NAME, WARMUP_FILE
910
from ai_infra_bench.sgl import cmp_bench
10-
from ai_infra_bench.utils import CSV_NAME, TABLE_NAME, WARMUP_FILE, ServerAccessInfo
11+
from ai_infra_bench.utils import ServerAccessInfo
1112

1213

1314
class TestClientCmp(unittest.TestCase):
1415
@classmethod
1516
def setUpClass(cls):
16-
cls.server_access_info = ServerAccessInfo(base_url="http://127.0.0.1:8000")
17+
cls.server_access_infos = ServerAccessInfo(
18+
base_url="http://127.0.0.1:8000", label="Test"
19+
)
1720
cls.client_cmds = """
1821
python -m sglang.bench_serving \
1922
--backend sglang-oai
@@ -28,10 +31,10 @@ def setUpClass(cls):
2831
--num-prompt 40
2932
"""
3033

31-
def run_single_cmd(self, server_access_info, client_cmds, **kwargs):
34+
def run_single_cmd(self, server_access_infos, client_cmds, **kwargs):
3235
with tempfile.TemporaryDirectory() as output_dir:
3336
client_cmp(
34-
server_access_info=server_access_info,
37+
server_access_infos=server_access_infos,
3538
client_cmds=client_cmds,
3639
input_features=input_features,
3740
output_metrics=output_metrics,
@@ -44,106 +47,85 @@ def run_single_cmd(self, server_access_info, client_cmds, **kwargs):
4447

4548
################## BASIC RUN ##########################
4649
def test_single_run(self):
47-
self.run_single_cmd(self.server_access_info, self.client_cmds)
50+
self.run_single_cmd(self.server_access_infos, self.client_cmds)
4851

4952
def test_multiple_run(self):
50-
self.run_single_cmd([self.server_access_info] * 2, [self.client_cmds] * 3)
53+
self.run_single_cmd([self.server_access_infos] * 2, [self.client_cmds] * 3)
5154

52-
self.run_single_cmd([self.server_access_info] * 2, self.client_cmds)
55+
self.run_single_cmd([self.server_access_infos] * 2, self.client_cmds)
5356

54-
self.run_single_cmd(self.server_access_info, [self.client_cmds] * 2)
57+
self.run_single_cmd(self.server_access_infos, [self.client_cmds] * 2)
5558

5659
def test_n_run(self):
57-
self.run_single_cmd(self.server_access_info, self.client_cmds, n=3)
60+
self.run_single_cmd(self.server_access_infos, self.client_cmds, n=3)
5861

59-
self.run_single_cmd([self.server_access_info] * 2, [self.client_cmds] * 3, n=3)
62+
self.run_single_cmd([self.server_access_infos] * 2, [self.client_cmds] * 3, n=3)
6063

61-
self.run_single_cmd([self.server_access_info] * 2, self.client_cmds, n=3)
64+
self.run_single_cmd([self.server_access_infos] * 2, self.client_cmds, n=3)
6265

63-
self.run_single_cmd(self.server_access_info, [self.client_cmds] * 2, n=3)
66+
self.run_single_cmd(self.server_access_infos, [self.client_cmds] * 2, n=3)
6467

6568
self.run_single_cmd(
66-
self.server_access_info, [self.client_cmds] * 2, n=3, only_last=True
69+
self.server_access_infos, [self.client_cmds] * 2, n=3, only_last=True
6770
)
6871

6972
################## LABEL SETTING ##########################
7073
def test_client_labels(self):
7174
self.run_single_cmd(
72-
self.server_access_info, self.client_cmds, client_labels=["client1"]
75+
self.server_access_infos, self.client_cmds, client_labels=["client1"]
7376
)
7477

7578
self.run_single_cmd(
76-
self.server_access_info, self.client_cmds, client_labels="client1"
79+
self.server_access_infos, self.client_cmds, client_labels="client1"
7780
)
7881

7982
self.run_single_cmd(
80-
self.server_access_info, [self.client_cmds] * 2, client_labels="client1"
83+
self.server_access_infos, [self.client_cmds] * 2, client_labels="client1"
8184
)
8285

8386
self.run_single_cmd(
84-
self.server_access_info,
87+
self.server_access_infos,
8588
[self.client_cmds] * 2,
8689
client_labels=["client1"] * 2,
8790
)
8891

8992
@unittest.expectedFailure
9093
def test_failed_client_labels(self):
9194
self.run_single_cmd(
92-
self.server_access_info,
95+
self.server_access_infos,
9396
[self.client_cmds] * 2,
9497
client_labels=["client1"] * 3,
9598
)
9699

97-
def test_server_labels(self):
98-
self.run_single_cmd(
99-
self.server_access_info, self.client_cmds, server_labels="server_label"
100-
)
101-
self.run_single_cmd(
102-
self.server_access_info, self.client_cmds, server_labels=["server_label"]
103-
)
104-
self.run_single_cmd(
105-
[self.server_access_info] * 2,
106-
self.client_cmds,
107-
server_labels=["server_label"] * 2,
108-
)
109-
110-
@unittest.expectedFailure
111-
def test_failed_server_labels(self):
112-
self.run_single_cmd(
113-
[self.server_access_info] * 2,
114-
self.client_cmds,
115-
server_labels=["server_label"] * 3,
116-
)
117-
118100
################## EXPECTED FAIL ###########################
119101
@unittest.expectedFailure
120102
def test_failed_host(self):
121103
self.run_single_cmd(
122-
self.server_access_info, self.client_cmds + " --host 127.0.0.1"
104+
self.server_access_infos, self.client_cmds + " --host 127.0.0.1"
123105
)
124106

125107
@unittest.expectedFailure
126108
def test_failed_port(self):
127-
self.run_single_cmd(self.server_access_info, self.client_cmds + " --port 8888")
109+
self.run_single_cmd(self.server_access_infos, self.client_cmds + " --port 8888")
128110

129111
@unittest.expectedFailure
130112
def test_failed_base_url(self):
131113
self.run_single_cmd(
132-
self.server_access_info,
114+
self.server_access_infos,
133115
self.client_cmds + " --base-url http://127.0.0.1:8888",
134116
)
135117

136118
@unittest.expectedFailure
137119
def test_failed_output_file(self):
138120
self.run_single_cmd(
139-
self.server_access_info, self.client_cmds + " --output-file output.jsonl"
121+
self.server_access_infos, self.client_cmds + " --output-file output.jsonl"
140122
)
141123

142124
################## DISABLE FEATURE ##########################
143125
def test_disable_warmup(self):
144126
with tempfile.TemporaryDirectory() as output_dir:
145127
client_cmp(
146-
server_access_info=self.server_access_info,
128+
server_access_infos=self.server_access_infos,
147129
client_cmds=self.client_cmds,
148130
input_features=input_features,
149131
output_metrics=output_metrics,
@@ -155,7 +137,7 @@ def test_disable_warmup(self):
155137
def test_disable_csv(self):
156138
with tempfile.TemporaryDirectory() as output_dir:
157139
client_cmp(
158-
server_access_info=self.server_access_info,
140+
server_access_infos=self.server_access_infos,
159141
client_cmds=self.client_cmds,
160142
input_features=input_features,
161143
output_metrics=output_metrics,
@@ -167,7 +149,7 @@ def test_disable_csv(self):
167149
def test_disable_md_table(self):
168150
with tempfile.TemporaryDirectory() as output_dir:
169151
client_cmp(
170-
server_access_info=self.server_access_info,
152+
server_access_infos=self.server_access_infos,
171153
client_cmds=self.client_cmds,
172154
input_features=input_features,
173155
output_metrics=output_metrics,
@@ -179,7 +161,7 @@ def test_disable_md_table(self):
179161
def test_disable_plot(self):
180162
with tempfile.TemporaryDirectory() as output_dir:
181163
client_cmp(
182-
server_access_info=self.server_access_info,
164+
server_access_infos=self.server_access_infos,
183165
client_cmds=self.client_cmds,
184166
input_features=input_features,
185167
output_metrics=output_metrics,

test/modes/test_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from utils import check_output_content, input_features, output_metrics
77

88
from ai_infra_bench.client import client_gen
9+
from ai_infra_bench.constants import CSV_NAME, TABLE_NAME, WARMUP_FILE
910
from ai_infra_bench.sgl import gen_bench
10-
from ai_infra_bench.utils import CSV_NAME, TABLE_NAME, WARMUP_FILE
1111

1212

1313
class TestClientGen(unittest.TestCase):

test/modes/test_slo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def check_slo(item: Dict) -> bool:
3030

3131

3232
from ai_infra_bench.client import client_slo
33+
from ai_infra_bench.constants import CSV_NAME, FULL_DATA_JSON_PATH, TABLE_NAME
3334
from ai_infra_bench.sgl import slo_bench
34-
from ai_infra_bench.utils import CSV_NAME, FULL_DATA_JSON_PATH, TABLE_NAME
3535

3636

3737
class TestSGLSlo(unittest.TestCase):

test/modes/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import os
22

3-
from ai_infra_bench.utils import CSV_NAME, FULL_DATA_JSON_PATH, TABLE_NAME, WARMUP_FILE
3+
from ai_infra_bench.constants import (
4+
CSV_NAME,
5+
FULL_DATA_JSON_PATH,
6+
TABLE_NAME,
7+
WARMUP_FILE,
8+
)
49

510
input_features = [
611
"random_input_len",

0 commit comments

Comments
 (0)