66from utils import check_output_content , input_features , output_metrics
77
88from ai_infra_bench .client import client_cmp
9+ from ai_infra_bench .constants import CSV_NAME , TABLE_NAME , WARMUP_FILE
910from ai_infra_bench .sgl import cmp_bench
10- from ai_infra_bench .utils import CSV_NAME , TABLE_NAME , WARMUP_FILE , ServerAccessInfo
11+ from ai_infra_bench .utils import ServerAccessInfo
1112
1213
1314class TestClientCmp (unittest .TestCase ):
1415 @classmethod
1516 def setUpClass (cls ):
16- cls .server_access_info = ServerAccessInfo (base_url = "http://127.0.0.1:8000" )
17+ cls .server_access_infos = ServerAccessInfo (
18+ base_url = "http://127.0.0.1:8000" , label = "Test"
19+ )
1720 cls .client_cmds = """
1821 python -m sglang.bench_serving \
1922 --backend sglang-oai
@@ -28,10 +31,10 @@ def setUpClass(cls):
2831 --num-prompt 40
2932 """
3033
31- def run_single_cmd (self , server_access_info , client_cmds , ** kwargs ):
34+ def run_single_cmd (self , server_access_infos , client_cmds , ** kwargs ):
3235 with tempfile .TemporaryDirectory () as output_dir :
3336 client_cmp (
34- server_access_info = server_access_info ,
37+ server_access_infos = server_access_infos ,
3538 client_cmds = client_cmds ,
3639 input_features = input_features ,
3740 output_metrics = output_metrics ,
@@ -44,106 +47,85 @@ def run_single_cmd(self, server_access_info, client_cmds, **kwargs):
4447
4548 ################## BASIC RUN ##########################
4649 def test_single_run (self ):
47- self .run_single_cmd (self .server_access_info , self .client_cmds )
50+ self .run_single_cmd (self .server_access_infos , self .client_cmds )
4851
4952 def test_multiple_run (self ):
50- self .run_single_cmd ([self .server_access_info ] * 2 , [self .client_cmds ] * 3 )
53+ self .run_single_cmd ([self .server_access_infos ] * 2 , [self .client_cmds ] * 3 )
5154
52- self .run_single_cmd ([self .server_access_info ] * 2 , self .client_cmds )
55+ self .run_single_cmd ([self .server_access_infos ] * 2 , self .client_cmds )
5356
54- self .run_single_cmd (self .server_access_info , [self .client_cmds ] * 2 )
57+ self .run_single_cmd (self .server_access_infos , [self .client_cmds ] * 2 )
5558
5659 def test_n_run (self ):
57- self .run_single_cmd (self .server_access_info , self .client_cmds , n = 3 )
60+ self .run_single_cmd (self .server_access_infos , self .client_cmds , n = 3 )
5861
59- self .run_single_cmd ([self .server_access_info ] * 2 , [self .client_cmds ] * 3 , n = 3 )
62+ self .run_single_cmd ([self .server_access_infos ] * 2 , [self .client_cmds ] * 3 , n = 3 )
6063
61- self .run_single_cmd ([self .server_access_info ] * 2 , self .client_cmds , n = 3 )
64+ self .run_single_cmd ([self .server_access_infos ] * 2 , self .client_cmds , n = 3 )
6265
63- self .run_single_cmd (self .server_access_info , [self .client_cmds ] * 2 , n = 3 )
66+ self .run_single_cmd (self .server_access_infos , [self .client_cmds ] * 2 , n = 3 )
6467
6568 self .run_single_cmd (
66- self .server_access_info , [self .client_cmds ] * 2 , n = 3 , only_last = True
69+ self .server_access_infos , [self .client_cmds ] * 2 , n = 3 , only_last = True
6770 )
6871
6972 ################## LABEL SETTING ##########################
7073 def test_client_labels (self ):
7174 self .run_single_cmd (
72- self .server_access_info , self .client_cmds , client_labels = ["client1" ]
75+ self .server_access_infos , self .client_cmds , client_labels = ["client1" ]
7376 )
7477
7578 self .run_single_cmd (
76- self .server_access_info , self .client_cmds , client_labels = "client1"
79+ self .server_access_infos , self .client_cmds , client_labels = "client1"
7780 )
7881
7982 self .run_single_cmd (
80- self .server_access_info , [self .client_cmds ] * 2 , client_labels = "client1"
83+ self .server_access_infos , [self .client_cmds ] * 2 , client_labels = "client1"
8184 )
8285
8386 self .run_single_cmd (
84- self .server_access_info ,
87+ self .server_access_infos ,
8588 [self .client_cmds ] * 2 ,
8689 client_labels = ["client1" ] * 2 ,
8790 )
8891
8992 @unittest .expectedFailure
9093 def test_failed_client_labels (self ):
9194 self .run_single_cmd (
92- self .server_access_info ,
95+ self .server_access_infos ,
9396 [self .client_cmds ] * 2 ,
9497 client_labels = ["client1" ] * 3 ,
9598 )
9699
97- def test_server_labels (self ):
98- self .run_single_cmd (
99- self .server_access_info , self .client_cmds , server_labels = "server_label"
100- )
101- self .run_single_cmd (
102- self .server_access_info , self .client_cmds , server_labels = ["server_label" ]
103- )
104- self .run_single_cmd (
105- [self .server_access_info ] * 2 ,
106- self .client_cmds ,
107- server_labels = ["server_label" ] * 2 ,
108- )
109-
110- @unittest .expectedFailure
111- def test_failed_server_labels (self ):
112- self .run_single_cmd (
113- [self .server_access_info ] * 2 ,
114- self .client_cmds ,
115- server_labels = ["server_label" ] * 3 ,
116- )
117-
118100 ################## EXPECTED FAIL ###########################
119101 @unittest .expectedFailure
120102 def test_failed_host (self ):
121103 self .run_single_cmd (
122- self .server_access_info , self .client_cmds + " --host 127.0.0.1"
104+ self .server_access_infos , self .client_cmds + " --host 127.0.0.1"
123105 )
124106
125107 @unittest .expectedFailure
126108 def test_failed_port (self ):
127- self .run_single_cmd (self .server_access_info , self .client_cmds + " --port 8888" )
109+ self .run_single_cmd (self .server_access_infos , self .client_cmds + " --port 8888" )
128110
129111 @unittest .expectedFailure
130112 def test_failed_base_url (self ):
131113 self .run_single_cmd (
132- self .server_access_info ,
114+ self .server_access_infos ,
133115 self .client_cmds + " --base-url http://127.0.0.1:8888" ,
134116 )
135117
136118 @unittest .expectedFailure
137119 def test_failed_output_file (self ):
138120 self .run_single_cmd (
139- self .server_access_info , self .client_cmds + " --output-file output.jsonl"
121+ self .server_access_infos , self .client_cmds + " --output-file output.jsonl"
140122 )
141123
142124 ################## DISABLE FEATURE ##########################
143125 def test_disable_warmup (self ):
144126 with tempfile .TemporaryDirectory () as output_dir :
145127 client_cmp (
146- server_access_info = self .server_access_info ,
128+ server_access_infos = self .server_access_infos ,
147129 client_cmds = self .client_cmds ,
148130 input_features = input_features ,
149131 output_metrics = output_metrics ,
@@ -155,7 +137,7 @@ def test_disable_warmup(self):
155137 def test_disable_csv (self ):
156138 with tempfile .TemporaryDirectory () as output_dir :
157139 client_cmp (
158- server_access_info = self .server_access_info ,
140+ server_access_infos = self .server_access_infos ,
159141 client_cmds = self .client_cmds ,
160142 input_features = input_features ,
161143 output_metrics = output_metrics ,
@@ -167,7 +149,7 @@ def test_disable_csv(self):
167149 def test_disable_md_table (self ):
168150 with tempfile .TemporaryDirectory () as output_dir :
169151 client_cmp (
170- server_access_info = self .server_access_info ,
152+ server_access_infos = self .server_access_infos ,
171153 client_cmds = self .client_cmds ,
172154 input_features = input_features ,
173155 output_metrics = output_metrics ,
@@ -179,7 +161,7 @@ def test_disable_md_table(self):
179161 def test_disable_plot (self ):
180162 with tempfile .TemporaryDirectory () as output_dir :
181163 client_cmp (
182- server_access_info = self .server_access_info ,
164+ server_access_infos = self .server_access_infos ,
183165 client_cmds = self .client_cmds ,
184166 input_features = input_features ,
185167 output_metrics = output_metrics ,
0 commit comments