Skip to content

Commit 14befa9

Browse files
committed
Update check.py
1 parent 15996b4 commit 14befa9

File tree

1 file changed

+189
-15
lines changed

1 file changed

+189
-15
lines changed

ai_infra_bench/check.py

Lines changed: 189 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,127 @@
66

77
from ai_infra_bench.utils import is_ci
88

9+
try:
10+
import sglang
11+
12+
is_sglang_available = True
13+
except ImportError:
14+
is_sglang_available = False
15+
916
logger = logging.getLogger(__name__)
1017

18+
DEFAULT_BENCH_SERVING_PATH = "/tmp/ai_infra_bench/bench_serving.py"
19+
20+
21+
def ensure_bench_serving_available() -> None:
22+
"""
23+
Automatically download bench_serving.py if sglang is not available.
24+
Also installs required dependencies.
25+
"""
26+
if is_sglang_available:
27+
return
28+
29+
if os.path.exists(DEFAULT_BENCH_SERVING_PATH):
30+
logger.info(f"bench_serving.py already exists at {DEFAULT_BENCH_SERVING_PATH}")
31+
return
32+
33+
logger.info("sglang is not available, downloading bench_serving.py...")
34+
try:
35+
import requests
36+
37+
raw_url = "https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py"
38+
response = requests.get(raw_url, timeout=30)
39+
response.raise_for_status()
40+
41+
os.makedirs(os.path.dirname(DEFAULT_BENCH_SERVING_PATH), exist_ok=True)
42+
with open(DEFAULT_BENCH_SERVING_PATH, "w") as f:
43+
f.write(response.text)
44+
logger.info(
45+
f"Successfully downloaded bench_serving.py to {DEFAULT_BENCH_SERVING_PATH}"
46+
)
47+
48+
# Install dependencies required by bench_serving.py
49+
install_bench_serving_dependencies()
50+
51+
except Exception as e:
52+
logger.error(
53+
f"Failed to download bench_serving.py from {raw_url}: {e}. "
54+
f"Please ensure you have internet access or manually download the file."
55+
)
56+
raise
57+
58+
59+
def install_bench_serving_dependencies() -> None:
60+
"""
61+
Parse bench_serving.py and install missing dependencies.
62+
"""
63+
import ast
64+
import subprocess
65+
import sys
66+
67+
logger.info("Installing bench_serving.py dependencies...")
68+
69+
try:
70+
with open(DEFAULT_BENCH_SERVING_PATH, "r") as f:
71+
tree = ast.parse(f.read())
72+
73+
# Extract all imported modules
74+
imports = set()
75+
for node in ast.walk(tree):
76+
if isinstance(node, ast.Import):
77+
for alias in node.names:
78+
imports.add(alias.name.split(".")[0])
79+
elif isinstance(node, ast.ImportFrom):
80+
if node.module:
81+
imports.add(node.module.split(".")[0])
82+
83+
# Standard library modules to skip
84+
stdlib_modules = set(sys.stdlib_module_names)
85+
86+
# Filter out standard library modules
87+
third_party_imports = {mod for mod in imports if mod not in stdlib_modules}
88+
89+
if not third_party_imports:
90+
logger.info("No third-party dependencies found to install")
91+
return
92+
93+
logger.info(
94+
"Found dependencies to install: %s",
95+
", ".join(sorted(third_party_imports)),
96+
)
97+
98+
# Try to install each dependency
99+
for module in sorted(third_party_imports):
100+
try:
101+
__import__(module)
102+
logger.debug(f" {module} is already installed")
103+
except ImportError:
104+
# Map module names to pip package names (handle common cases)
105+
package_map = {
106+
"cv2": "opencv-python",
107+
"PIL": "Pillow",
108+
"yaml": "PyYAML",
109+
"bs4": "beautifulsoup4",
110+
}
111+
package_name = package_map.get(module, module)
112+
113+
logger.info(" Installing %s...", package_name)
114+
subprocess.check_call(
115+
[sys.executable, "-m", "pip", "install", package_name],
116+
stdout=subprocess.DEVNULL,
117+
stderr=subprocess.DEVNULL,
118+
)
119+
logger.debug(f" {package_name} installed successfully")
120+
121+
logger.info("All dependencies installed successfully")
122+
123+
except Exception as e:
124+
logger.warning(
125+
f"Failed to automatically install dependencies: {e}. "
126+
f"Please manually install any missing dependencies if the script fails."
127+
)
128+
129+
11130
SGLANG_KEYS = [
12131
"backend",
13132
"dataset_name",
@@ -16,21 +135,21 @@
16135
"sharegpt_output_len",
17136
"random_input_len",
18137
"random_output_len",
19-
"random_range _ratio",
138+
"random_range_ratio",
20139
"duration",
21140
"completed",
22141
"total_input_tokens",
23142
"total_output_tokens",
24143
"total_output_tokens_retokenized",
25144
"request_throughput",
26-
"input_through put",
145+
"input_throughput",
27146
"output_throughput",
28147
"mean_e2e_latency_ms",
29148
"median_e2e_latency_ms",
30149
"std_e2e_latency_ms",
31150
"p99_e2e_latency_ms",
32151
"mean_ttft_ms",
33-
"median_ttft_ms ",
152+
"median_ttft_ms",
34153
"std_ttft_ms",
35154
"p99_ttft_ms",
36155
"mean_tpot_ms",
@@ -40,7 +159,7 @@
40159
"mean_itl_ms",
41160
"median_itl_ms",
42161
"std_itl_ms",
43-
"p95_it l_ms",
162+
"p95_itl_ms",
44163
"p99_itl_ms",
45164
"concurrency",
46165
"accept_length",
@@ -53,7 +172,7 @@ def check_dir(output_dir: str, full_data_json_path):
53172
for an action (delete or rename). It re-prompts on invalid input.
54173
"""
55174
if is_ci():
56-
os.makedirs(os.path.join(output_dir, full_data_json_path))
175+
os.makedirs(os.path.join(output_dir, full_data_json_path), exist_ok=True)
57176
return output_dir
58177

59178
if os.path.exists(output_dir):
@@ -85,10 +204,11 @@ def check_dir(output_dir: str, full_data_json_path):
85204
output_dir = input("New directory name: ").strip()
86205
os.makedirs(output_dir)
87206
logger.info(f"New directory created: '{output_dir}'.")
207+
break
88208
elif option == "4":
89209
exit(0)
90210
else:
91-
logger.warning("Invalid option. Please enter '1', '2' or '3'.")
211+
logger.warning("Invalid option. Please enter '1', '2', '3', or '4'.")
92212
else:
93213
# If the directory does not exist, create it directly
94214
os.makedirs(output_dir)
@@ -98,6 +218,40 @@ def check_dir(output_dir: str, full_data_json_path):
98218
return output_dir
99219

100220

221+
def check_content_client_cmds(client_cmds: List[List[str]]) -> None:
222+
if is_sglang_available:
223+
for client_cmd in client_cmds:
224+
for cmd in client_cmd:
225+
assert any(
226+
cmd.strip().startswith(p)
227+
for p in [
228+
"python -m sglang.bench_serving",
229+
"python3 -m sglang.bench_serving",
230+
]
231+
), f"Each client_cmd must start with 'python -m sglang.bench_serving' or 'python3 -m sglang.bench_serving', but found {cmd=}"
232+
else:
233+
# Ensure bench_serving is available if sglang is not installed
234+
ensure_bench_serving_available()
235+
for cmd_list_idx, client_cmd in enumerate(client_cmds):
236+
for cmd_idx, cmd in enumerate(client_cmd):
237+
if cmd.startswith("python -m sglang.bench_serving"):
238+
cmd = cmd.replace(
239+
"python -m sglang.bench_serving",
240+
f"python {DEFAULT_BENCH_SERVING_PATH}",
241+
)
242+
elif cmd.startswith("python3 -m sglang.bench_serving"):
243+
cmd = cmd.replace(
244+
"python3 -m sglang.bench_serving",
245+
f"python3 {DEFAULT_BENCH_SERVING_PATH}",
246+
)
247+
else:
248+
raise ValueError(
249+
f"Each client_cmd must start with 'python -m sglang.bench_serving' or 'python3 -m sglang.bench_serving', but found {cmd=}"
250+
)
251+
client_cmd[cmd_idx] = cmd
252+
client_cmds[cmd_list_idx] = client_cmd
253+
254+
101255
def check_content_server_client_cmds(
102256
server_cmds: List[str], client_cmds: List[List[str]]
103257
) -> None:
@@ -111,14 +265,34 @@ def check_content_server_client_cmds(
111265
), f"Each server_cmd must start with 'python -m sglang.launch_server' or 'python3 -m sglang.launch_server', but found {cmd=}"
112266

113267
for client_cmd in client_cmds:
114-
for cmd in client_cmd:
115-
assert any(
116-
cmd.strip().startswith(p)
117-
for p in [
118-
"python -m sglang.bench_serving",
119-
"python3 -m sglang.bench_serving",
120-
]
121-
), f"Each client_cmd must start with 'python -m sglang.bench_serving' or 'python3 -m sglang.bench_serving', but found {cmd=}"
268+
if is_sglang_available:
269+
for cmd in client_cmd:
270+
assert any(
271+
cmd.strip().startswith(p)
272+
for p in [
273+
"python -m sglang.bench_serving",
274+
"python3 -m sglang.bench_serving",
275+
]
276+
), f"Each client_cmd must start with 'python -m sglang.bench_serving' or 'python3 -m sglang.bench_serving', but found {cmd=}"
277+
else:
278+
# Ensure bench_serving is available if sglang is not installed
279+
ensure_bench_serving_available()
280+
for cmd_idx, cmd in enumerate(client_cmd):
281+
if cmd.startswith("python -m sglang.bench_serving"):
282+
cmd = cmd.replace(
283+
"python -m sglang.bench_serving",
284+
f"python {DEFAULT_BENCH_SERVING_PATH}",
285+
)
286+
elif cmd.startswith("python3 -m sglang.bench_serving"):
287+
cmd = cmd.replace(
288+
"python3 -m sglang.bench_serving",
289+
f"python3 {DEFAULT_BENCH_SERVING_PATH}",
290+
)
291+
else:
292+
raise ValueError(
293+
f"Each client_cmd must start with 'python -m sglang.bench_serving' or 'python3 -m sglang.bench_serving', but found {cmd=}"
294+
)
295+
client_cmd[cmd_idx] = cmd
122296

123297

124298
def check_values_in_features_metrics(input_features, output_metrics):
@@ -133,7 +307,7 @@ def check_values_in_features_metrics(input_features, output_metrics):
133307

134308
def check_param_in_cmd(param: str, cmds: List[str]):
135309
for cmd in cmds:
136-
assert param not in cmd, f"{cmd=} should not contain '{param}''"
310+
assert param not in cmd, f"{cmd=} should not contain '{param}'"
137311

138312

139313
def check_str_list_str(cmds: str | List[str]):

0 commit comments

Comments
 (0)