Skip to content

Commit f4867ec

Browse files
authored
[Test] Add Omni Model Performance Benchmark Test (vllm-project#1321)
Signed-off-by: yenuo26 <410167048@qq.com>
Signed-off-by: wangyu <53896905+yenuo26@users.noreply.github.com>
1 parent 00bd07b commit f4867ec

File tree

5 files changed

+696
-38
lines changed

5 files changed

+696
-38
lines changed

.buildkite/test-nightly.yaml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,40 @@ steps:
5555
- "HF_HOME=/fsx/hf_cache"
5656
volumes:
5757
- "/fsx/hf_cache:/fsx/hf_cache"
58+
59+
60+
- label: "Omni Model Perf Test"
61+
timeout_in_minutes: 120
62+
depends_on: image-build
63+
if: build.env("NIGHTLY") == "1"
64+
commands:
65+
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
66+
- pytest -s -v tests/perf/scripts/run_benchmark.py
67+
agents:
68+
queue: "mithril-h100-pool"
69+
plugins:
70+
- kubernetes:
71+
podSpec:
72+
containers:
73+
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
74+
resources:
75+
limits:
76+
nvidia.com/gpu: 2
77+
volumeMounts:
78+
- name: devshm
79+
mountPath: /dev/shm
80+
- name: hf-cache
81+
mountPath: /root/.cache/huggingface
82+
env:
83+
- name: HF_HOME
84+
value: /root/.cache/huggingface
85+
nodeSelector:
86+
node.kubernetes.io/instance-type: gpu-h100-sxm
87+
volumes:
88+
- name: devshm
89+
emptyDir:
90+
medium: Memory
91+
- name: hf-cache
92+
hostPath:
93+
path: /mnt/hf-cache
94+
type: DirectoryOrCreate

tests/conftest.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def delete_by_path(config_dict: dict, path: str) -> None:
828828
# Find stage by ID
829829
target_stage = None
830830
for stage in stage_args:
831-
if stage.get("stage_id") == stage_id:
831+
if stage.get("stage_id") == int(stage_id):
832832
target_stage = stage
833833
break
834834

@@ -847,43 +847,42 @@ def delete_by_path(config_dict: dict, path: str) -> None:
847847
# Delete entire key
848848
del config[key]
849849

850-
if updates:
851-
# Apply updates
852-
for key, value in updates.items():
853-
if key == "stage_args":
854-
if value and isinstance(value, dict):
855-
stage_args = config.get("stage_args", [])
856-
if not stage_args:
857-
raise ValueError("stage_args does not exist in config")
858-
859-
for stage_id, stage_updates in value.items():
860-
# Find stage by ID
861-
target_stage = None
862-
for stage in stage_args:
863-
if stage.get("stage_id") == stage_id:
864-
target_stage = stage
865-
break
866-
867-
if target_stage is None:
868-
available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
869-
raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
870-
871-
# Apply updates to this stage
872-
for path, val in stage_updates.items():
873-
# Check if this is a simple key (not dot-separated)
874-
# Example: 'engine_input_source' vs 'engine_args.max_model_len'
875-
if "." not in path:
876-
# Direct key assignment (e.g., updating a list value)
877-
target_stage[path] = val
878-
else:
879-
# Dot-separated path (e.g., nested dict access)
880-
apply_update(target_stage, path, val)
881-
elif "." in key:
882-
# Apply using dot-separated path
883-
apply_update(config, key, value)
884-
else:
885-
# Direct top-level key
886-
config[key] = value
850+
# Apply updates
851+
for key, value in updates.items():
852+
if key == "stage_args":
853+
if value and isinstance(value, dict):
854+
stage_args = config.get("stage_args", [])
855+
if not stage_args:
856+
raise ValueError("stage_args does not exist in config")
857+
858+
for stage_id, stage_updates in value.items():
859+
# Find stage by ID
860+
target_stage = None
861+
for stage in stage_args:
862+
if stage.get("stage_id") == int(stage_id):
863+
target_stage = stage
864+
break
865+
866+
if target_stage is None:
867+
available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
868+
raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
869+
870+
# Apply updates to this stage
871+
for path, val in stage_updates.items():
872+
# Check if this is a simple key (not dot-separated)
873+
# Example: 'engine_input_source' vs 'engine_args.max_model_len'
874+
if "." not in path:
875+
# Direct key assignment (e.g., updating a list value)
876+
target_stage[path] = val
877+
else:
878+
# Dot-separated path (e.g., nested dict access)
879+
apply_update(target_stage, path, val)
880+
elif "." in key:
881+
# Apply using dot-separated path
882+
apply_update(config, key, value)
883+
else:
884+
# Direct top-level key
885+
config[key] = value
887886

888887
# Save to new file with timestamp
889888
timestamp = int(time.time())
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import os
2+
3+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
4+
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
5+
6+
import json
7+
import subprocess
8+
import threading
9+
from datetime import datetime
10+
from pathlib import Path
11+
from typing import Any
12+
13+
import pytest
14+
15+
from tests.conftest import OmniServer, modify_stage_config
16+
17+
18+
def load_configs(config_path: str) -> list[dict[str, Any]]:
    """Load and parse the benchmark configuration JSON file.

    Args:
        config_path: Path (relative or absolute) to the JSON config file.

    Returns:
        The parsed list of benchmark config dicts.

    Raises:
        ValueError: If the file is missing or contains invalid JSON.
        RuntimeError: For any other failure while reading the file.
    """
    abs_path = Path(config_path).resolve()
    try:
        # Keep the try body minimal: only the I/O + parse can raise here.
        with open(abs_path, encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        # Chain the original exception so the parse location is preserved.
        raise ValueError(f"JSON parsing error: {str(e)}") from e
    except FileNotFoundError:
        raise ValueError(f"Configuration file not found: {config_path}") from None
    except Exception as e:
        raise RuntimeError(f"Failed to load configuration file: {str(e)}") from e
32+
33+
34+
def modify_stage(default_path, updates, deletes):
    """Return a stage-config path, rewriting the config when overrides exist.

    When either ``updates`` or ``deletes`` is given, a modified copy of the
    config is produced via ``modify_stage_config`` and its path is returned;
    otherwise the original path is returned untouched.
    """
    overrides = {
        name: value
        for name, value in (("updates", updates), ("deletes", deletes))
        if value is not None
    }
    if not overrides:
        return default_path
    return modify_stage_config(default_path, **overrides)
46+
47+
48+
def create_unique_server_params(configs: list[dict[str, Any]]) -> list[tuple[str, str, str]]:
    """Collect the distinct (test_name, model, stage_config_path) tuples.

    Each tuple describes one server launch; duplicates across benchmark
    configs are collapsed via a set. Stage configs are resolved relative to
    the ``stage_configs`` directory and rewritten when the config carries
    ``update``/``delete`` overrides.
    """
    seen: set[tuple[str, str, str]] = set()
    for cfg in configs:
        server = cfg["server_params"]
        base_path = str(
            Path(__file__).parent.parent / "stage_configs" / server["stage_config_name"]
        )
        resolved = modify_stage(
            base_path, server.get("update", None), server.get("delete", None)
        )
        seen.add((cfg["test_name"], server["model"], resolved))
    return list(seen)
61+
62+
63+
def create_test_parameter_mapping(configs: list[dict[str, Any]]) -> dict[str, dict]:
    """Group benchmark parameter lists by test name.

    Multiple configs sharing a ``test_name`` have their ``benchmark_params``
    concatenated under a single entry.
    """
    grouped: dict[str, dict] = {}
    for cfg in configs:
        name = cfg["test_name"]
        entry = grouped.setdefault(name, {"test_name": name, "benchmark_params": []})
        entry["benchmark_params"].extend(cfg["benchmark_params"])
    return grouped
74+
75+
76+
# Benchmark definitions live beside this script in tests/test.json.
CONFIG_FILE_PATH = str(Path(__file__).parent.parent / "tests" / "test.json")
# Loaded once at import time; the pytest parametrization below depends on it.
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)


# One server launch per unique (test_name, model, stage_config_path) tuple.
test_params = create_unique_server_params(BENCHMARK_CONFIGS)
# Maps test_name -> {"test_name": ..., "benchmark_params": [...]}.
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS)

# Serializes server startup so two parametrized fixture instances never boot
# concurrently (multi-stage init is long and GPU memory is shared).
_omni_server_lock = threading.Lock()
84+
85+
86+
@pytest.fixture(scope="module")
def omni_server(request):
    """Start vLLM-Omni server as a subprocess with actual model weights.

    Module-scoped (the previous docstring incorrectly claimed session scope):
    the server starts once per module for each parametrized
    (test_name, model, stage_config_path) tuple.
    Multi-stage initialization can take 10-20+ minutes.
    """
    # The lock serializes startup across parametrized fixture instances so
    # two servers never initialize concurrently.
    with _omni_server_lock:
        test_name, model, stage_config_path = request.param

        print(f"Starting OmniServer with test: {test_name}, model: {model}")

        with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
            print("OmniServer started successfully")
            yield server
            print("OmniServer stopping...")

        print("OmniServer stopped")
103+
104+
105+
def run_benchmark(args: list, test_name: str, flow, dataset_name: str, num_prompt) -> Any:
    """Run ``vllm bench serve`` with *args* and return the parsed JSON result.

    A unique, timestamped result filename is generated per invocation so
    repeated runs never clobber each other's output.

    Args:
        args: Extra CLI arguments (host/port, dataset options, etc.).
        test_name: Logical test name, embedded in the result filename.
        flow: The sweep value (QPS or concurrency), embedded in the filename.
        dataset_name: Dataset label, embedded in the filename.
        num_prompt: Number of prompts for this run, embedded in the filename.

    Returns:
        The benchmark result parsed from the saved JSON file.

    Raises:
        RuntimeError: If the benchmark subprocess exits non-zero.
    """
    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
    result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_{current_dt}.json"
    if "--result-filename" in args:
        print(f"The result file will be overwritten by {result_filename}")
    command = (
        ["vllm", "bench", "serve", "--omni"]
        + args
        + [
            "--backend",
            "openai-chat-omni",
            "--endpoint",
            "/v1/chat/completions",
            "--save-result",
            "--result-filename",
            result_filename,
        ]
    )
    # Merge stderr into stdout and stream a single pipe. The previous code
    # read stdout to EOF before touching stderr, which deadlocks once the
    # unread stderr pipe buffer fills while the child is still writing.
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True
    )
    for line in iter(process.stdout.readline, ""):
        print(line, end="")  # lines already carry their trailing newline

    returncode = process.wait()
    if returncode != 0:
        # Fail loudly instead of crashing later on a missing result file.
        raise RuntimeError(f"Benchmark command exited with code {returncode}: {' '.join(command)}")

    # The benchmark writes the result file into --result-dir (cwd by default).
    if "--result-dir" in args:
        index = args.index("--result-dir")
        result_dir = args[index + 1]
    else:
        result_dir = "./"

    with open(os.path.join(result_dir, result_filename), encoding="utf-8") as f:
        result = json.load(f)
    return result
143+
144+
145+
def get_benchmark_params_for_server(test_name: str) -> list:
    """Return the benchmark parameter list for *test_name* (empty if unknown)."""
    entry = server_to_benchmark_mapping.get(test_name)
    return entry["benchmark_params"] if entry is not None else []
149+
150+
151+
def create_benchmark_indices():
    """Flatten the server mapping into (test_name, param_index) pairs.

    Each pair addresses one benchmark-parameter set, letting pytest
    parametrize over them without duplicating the parameter dicts.
    """
    return [
        (test_name, idx)
        for test_name, config_data in server_to_benchmark_mapping.items()
        for idx in range(len(config_data["benchmark_params"]))
    ]
158+
159+
160+
# Materialized once so the fixture params and the indirect parametrize
# decorator below share one consistent ordering.
benchmark_indices = create_benchmark_indices()
161+
162+
163+
@pytest.fixture(params=benchmark_indices)
def benchmark_params(request, omni_server):
    """Benchmark parameters fixture with proper parametrization"""
    test_name, param_index = request.param
    candidates = get_benchmark_params_for_server(test_name)

    # Guard against stale indices: the mapping and the index list are built
    # from the same configs, so a miss indicates a config error.
    if not candidates:
        raise ValueError(f"No benchmark parameters found for test: {test_name}")
    if param_index >= len(candidates):
        raise ValueError(f"No benchmark parameters found for index {param_index} in test: {test_name}")

    return {"test_name": test_name, "params": candidates[param_index]}
176+
177+
178+
def assert_result(result, params, num_prompt):
    """Verify all requests completed and each metric meets its baseline."""
    assert result["completed"] == num_prompt, "Request failures exist"
    for metric_name, baseline_value in params.get("baseline", {}).items():
        current_value = result[metric_name]
        if "throughput" in metric_name:
            # Throughput metrics: higher is better.
            assert current_value >= baseline_value, f"{metric_name}: {current_value} < {baseline_value}"
        else:
            # Latency-style metrics: must not exceed the baseline.
            assert current_value <= baseline_value, f"{metric_name}: {current_value} > {baseline_value}"
187+
188+
189+
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
@pytest.mark.parametrize("benchmark_params", benchmark_indices, indirect=True)
def test_performance_benchmark(omni_server, benchmark_params):
    """Run QPS and concurrency benchmark sweeps against a live server and
    compare the results to their configured baselines."""
    test_name = benchmark_params["test_name"]
    params = benchmark_params["params"]
    dataset_name = params.get("dataset_name", "")

    host = omni_server.host
    port = omni_server.port
    model = omni_server.model

    print(f"Running benchmark for model: {model}")
    print(f"Benchmark parameters: {benchmark_params}")

    def to_list(value, default=None):
        # Normalize scalar-or-list config values into a list.
        if value is None:
            return [] if default is None else [default]
        return [value] if not isinstance(value, (list, tuple)) else list(value)

    qps_list = to_list(params.get("request_rate"))
    num_prompt_list = to_list(params.get("num_prompts"))
    max_concurrency_list = to_list(params.get("max_concurrency"))

    # Broadcast singleton lists so qps/concurrency/num_prompts line up 1:1.
    max_len = max(len(qps_list), len(max_concurrency_list))
    if len(num_prompt_list) == 1 and max_len > 1:
        num_prompt_list = num_prompt_list * max_len
    elif max_len == 1 and len(num_prompt_list) > 1:
        if len(qps_list) == 1:
            qps_list = qps_list * len(num_prompt_list)
        if len(max_concurrency_list) == 1:
            max_concurrency_list = max_concurrency_list * len(num_prompt_list)
        max_len = max(len(qps_list), len(max_concurrency_list))
    elif len(num_prompt_list) != max_len and max_len > 0:
        raise ValueError("The number of prompts does not match the QPS or max_concurrency")

    args = ["--host", host, "--port", str(port)]
    exclude_keys = {"request_rate", "baseline", "num_prompts", "max_concurrency"}

    for key, value in params.items():
        if key in exclude_keys or value is None:
            continue

        arg_name = f"--{key.replace('_', '-')}"

        if isinstance(value, bool) and value:
            args.append(arg_name)  # boolean flags take no value
        elif isinstance(value, dict):
            json_str = json.dumps(value, ensure_ascii=False, separators=(",", ":"))
            args.extend([arg_name, json_str])
        elif not isinstance(value, bool):
            args.extend([arg_name, str(value)])

    # QPS test.
    # BUGFIX: build a per-run argument list instead of rebinding ``args``.
    # The original did ``args = args + [...]`` inside the loop, so
    # --request-rate/--num-prompts flags accumulated across iterations and
    # leaked from the QPS sweep into the concurrency sweep below.
    for qps, num_prompt in zip(qps_list, num_prompt_list):
        run_args = args + ["--request-rate", str(qps), "--num-prompts", str(num_prompt)]
        result = run_benchmark(
            args=run_args, test_name=test_name, flow=qps, dataset_name=dataset_name, num_prompt=num_prompt
        )
        assert_result(result, params, num_prompt=num_prompt)

    # concurrency test
    for concurrency, num_prompt in zip(max_concurrency_list, num_prompt_list):
        run_args = args + [
            "--max-concurrency", str(concurrency), "--num-prompts", str(num_prompt), "--request-rate", "inf"
        ]
        result = run_benchmark(
            args=run_args, test_name=test_name, flow=concurrency, dataset_name=dataset_name, num_prompt=num_prompt
        )
        assert_result(result, params, num_prompt=num_prompt)

0 commit comments

Comments
 (0)