Skip to content

Commit 84b5076

Browse files
committed
Update
1 parent 39bc9ec commit 84b5076

File tree

4 files changed

+210
-62
lines changed

4 files changed

+210
-62
lines changed

ai_infra_bench/modes/cmp.py

Lines changed: 131 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ def cmp_plot(data, input_features, metrics, labels, output_dir):
7171
print("Ploting graphs DONE")
7272

7373

74-
@enter_decorate("CMP EXPORT TBALE", filename=TABLE_NAME)
74+
@enter_decorate("CMP EXPORT TABLE", filename=TABLE_NAME)
7575
def cmp_export_table(
7676
all_clients_results: List[List[Dict]],
7777
input_features: List[str],
78-
output_metrics: List[Dict],
78+
output_metrics: List[str],
7979
num_clients: int,
8080
num_servers: int,
8181
output_dir: str,
@@ -84,45 +84,140 @@ def cmp_export_table(
8484
if not all_clients_results or not all_clients_results[0]:
8585
raise ValueError("No data available to export.")
8686

87-
if server_labels[0] is None:
87+
if server_labels is None or server_labels[0] is None:
8888
server_labels = [f"server_{i + 1}" for i in range(num_servers)]
8989

90-
# header
91-
header_cells = input_features + [" - "]
92-
for output_metric in output_metrics:
93-
header_cells += [output_metric] + [" - "] * (len(server_labels) - 1)
94-
header_row = "| " + " | ".join(map(str, header_cells)) + " |"
90+
# --- 1. 动态构建表头 ---
91+
# 将 input_features 组合成标题,例如: "Config (input_len / output_len / rate)"
92+
config_header_name = f"Config ({' / '.join(input_features)})"
9593

96-
# sub header
97-
sub_header_cells = [" - "] * (len(input_features) + 1) + server_labels * len(
98-
output_metrics
99-
)
100-
sub_header_row = "| " + " | ".join(map(str, sub_header_cells)) + " |"
94+
header_cells = [config_header_name, "Metric"] + server_labels
95+
if num_servers == 2:
96+
header_cells.append("Diff (%)")
10197

98+
header_row = "| " + " | ".join(header_cells) + " |"
10299
separator_row = "| " + " | ".join(["---"] * len(header_cells)) + " |"
103-
lines = [header_row, sub_header_row, separator_row]
100+
lines = [header_row, separator_row]
104101

102+
# --- 2. 遍历每一个配置 (Client Config) ---
105103
for client_idx in range(num_clients):
106-
#
107-
row_values = []
108-
109-
all_server_metrics = []
110-
for server_idx in range(num_servers):
111-
server_metrics = []
112-
idx = client_idx + server_idx * num_clients
113-
row_results = all_clients_results[idx]
114-
if server_idx == 0:
115-
for feature in input_features:
116-
row_values.append(f"{row_results[0][feature]:.2f}")
117-
row_values.append("-")
118-
for metric in output_metrics:
119-
server_metrics.append(avg_std_strf(metric, row_results, precision=2))
120-
all_server_metrics.append(server_metrics)
121-
122-
for i in range(len(output_metrics)):
123-
for j in range(num_servers):
124-
row_values.append(all_server_metrics[j][i])
125-
lines.append("| " + " | ".join(row_values) + " |")
126-
127-
with open(os.path.join(output_dir, TABLE_NAME), mode="w", encoding="utf-8") as f:
104+
105+
# 动态提取当前配置下所有 feature 的值
106+
# 索引逻辑: client_idx 对应第一个 server 的该配置结果
107+
first_server_res_list = all_clients_results[client_idx]
108+
first_sample = first_server_res_list[0]
109+
110+
config_val_list = []
111+
for feat in input_features:
112+
val = first_sample.get(feat, "N/A")
113+
# 格式化数值:如果是浮点数保留两位,否则转字符串
114+
if isinstance(val, float):
115+
config_val_list.append(f"{val:.2f}")
116+
else:
117+
config_val_list.append(str(val))
118+
119+
# 拼接后的配置字符串,例如 "1200.00 / 800.00 / 4.00"
120+
config_str = " / ".join(config_val_list)
121+
122+
# --- 3. 遍历每一个指标 (Metric) ---
123+
for m_idx, metric in enumerate(output_metrics):
124+
row_values = []
125+
126+
# 第一列:仅在指标块的第一行显示配置
127+
if m_idx == 0:
128+
row_values.append(f"**{config_str}**")
129+
else:
130+
row_values.append(" ")
131+
132+
# 第二列:指标名称
133+
row_values.append(metric)
134+
135+
# 后面几列:各个 Server 的数值
136+
numerical_means = []
137+
for s_idx in range(num_servers):
138+
idx = client_idx + s_idx * num_clients
139+
res_list = all_clients_results[idx]
140+
141+
# 使用你原有的格式化函数获取 "均值 ± 标准差"
142+
display_str = avg_std_strf(metric, res_list, precision=2)
143+
row_values.append(display_str)
144+
145+
# 为计算 Diff 提取纯数值均值
146+
try:
147+
m_val = sum(r[metric] for r in res_list) / len(res_list)
148+
numerical_means.append(m_val)
149+
except:
150+
numerical_means.append(None)
151+
152+
# 最后一列:动态计算两个 Server 间的差异
153+
if num_servers == 2:
154+
v1, v2 = numerical_means[0], numerical_means[1]
155+
if v1 is not None and v2 is not None and v1 != 0:
156+
diff = (v2 - v1) / v1 * 100
157+
row_values.append(f"{diff:+.2f}%")
158+
else:
159+
row_values.append("-")
160+
161+
lines.append("| " + " | ".join(row_values) + " |")
162+
163+
# --- 4. 写入文件 ---
164+
output_path = os.path.join(output_dir, TABLE_NAME)
165+
with open(output_path, mode="w", encoding="utf-8") as f:
128166
f.write("\n".join(lines))
167+
168+
169+
# @enter_decorate("CMP EXPORT TBALE", filename=TABLE_NAME)
170+
# def cmp_export_table(
171+
# all_clients_results: List[List[Dict]],
172+
# input_features: List[str],
173+
# output_metrics: List[Dict],
174+
# num_clients: int,
175+
# num_servers: int,
176+
# output_dir: str,
177+
# server_labels: List[str],
178+
# ):
179+
# if not all_clients_results or not all_clients_results[0]:
180+
# raise ValueError("No data available to export.")
181+
#
182+
# if server_labels[0] is None:
183+
# server_labels = [f"server_{i + 1}" for i in range(num_servers)]
184+
#
185+
# # header
186+
# header_cells = input_features + [" - "]
187+
# for output_metric in output_metrics:
188+
# header_cells += [output_metric] + [" - "] * (len(server_labels) - 1)
189+
# header_row = "| " + " | ".join(map(str, header_cells)) + " |"
190+
#
191+
# # sub header
192+
# sub_header_cells = [" - "] * (len(input_features) + 1) + server_labels * len(
193+
# output_metrics
194+
# )
195+
# sub_header_row = "| " + " | ".join(map(str, sub_header_cells)) + " |"
196+
#
197+
# separator_row = "| " + " | ".join(["---"] * len(header_cells)) + " |"
198+
# lines = [header_row, sub_header_row, separator_row]
199+
#
200+
# for client_idx in range(num_clients):
201+
# #
202+
# row_values = []
203+
#
204+
# all_server_metrics = []
205+
# for server_idx in range(num_servers):
206+
# server_metrics = []
207+
# idx = client_idx + server_idx * num_clients
208+
# row_results = all_clients_results[idx]
209+
# if server_idx == 0:
210+
# for feature in input_features:
211+
# row_values.append(f"{row_results[0][feature]:.2f}")
212+
# row_values.append("-")
213+
# for metric in output_metrics:
214+
# server_metrics.append(avg_std_strf(metric, row_results, precision=2))
215+
# all_server_metrics.append(server_metrics)
216+
#
217+
# for i in range(len(output_metrics)):
218+
# for j in range(num_servers):
219+
# row_values.append(all_server_metrics[j][i])
220+
# lines.append("| " + " | ".join(row_values) + " |")
221+
#
222+
# with open(os.path.join(output_dir, TABLE_NAME), mode="w", encoding="utf-8") as f:
223+
# f.write("\n".join(lines))

ai_infra_bench/sgl/cmp_bench.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
maybe_create_labels,
2020
maybe_warmup,
2121
run_cmd,
22+
stop_server_process,
2223
wait_for_server,
2324
)
2425

@@ -91,9 +92,7 @@ def cmp_bench(
9192
output_dir=output_dir,
9293
)
9394
)
94-
95-
if server_process:
96-
server_process.terminate()
95+
stop_server_process(server_process)
9796

9897
if not disable_csv:
9998
gen_export_csv(

ai_infra_bench/utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,53 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
287287
itself.send_signal(signal.SIGQUIT)
288288
except psutil.NoSuchProcess:
289289
pass
290+
291+
292+
def stop_server_process(process, timeout=30, cooldown_period=3) -> None:
    """
    Stop a server sub-process: graceful terminate first, forced kill second.

    The function is best-effort: any error raised while stopping the process
    is logged (with traceback) and swallowed, so the caller can carry on.

    Args:
        process (subprocess.Popen): The process object to stop. ``None`` is
            accepted and results in a warning plus an immediate return.
        timeout (int): Seconds to wait for the graceful shutdown. The same
            bound is applied to the wait that follows the forced kill, so
            this function can never block indefinitely on a stuck process.
        cooldown_period (int): Extra seconds to sleep after the stop attempt
            so that VRAM and the port can be released. Skipped when the
            process is ``None`` or has already exited.
    """
    if process is None:
        logger.warning("No process found to terminate.")
        return

    # Check if the process is already dead; nothing to stop, no cooldown.
    if process.poll() is not None:
        logger.info(
            f"Process (PID: {process.pid}) has already exited with code: {process.returncode}"
        )
        return

    try:
        logger.info(f"Sending SIGTERM to process {process.pid} (Graceful Shutdown)...")
        # 1. Start graceful shutdown
        process.terminate()

        try:
            # 2. Block and wait for the process to die
            process.wait(timeout=timeout)
            logger.info("Server exited gracefully.")
        except subprocess.TimeoutExpired:
            # 3. If it takes too long, force kill it
            logger.warning(
                f"Server did not exit within {timeout}s. Sending SIGKILL (Force Kill)..."
            )
            process.kill()
            # Reap the process so it leaves the OS process table. The wait is
            # bounded: if the process still does not exit, TimeoutExpired
            # propagates to the handler below and is logged instead of
            # hanging the caller forever.
            process.wait(timeout=timeout)
            logger.info("Server was forcibly killed.")

    except Exception as e:
        # Deliberately best-effort: log with the full traceback and continue.
        logger.exception(f"An error occurred while stopping the server: {e}")

    # 4. Critical: Wait for hardware/network cleanup
    if cooldown_period > 0:
        logger.info(
            f"Waiting {cooldown_period}s for GPU VRAM and TCP ports to be fully released..."
        )
        time.sleep(cooldown_period)

examples/cmp_bench.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,27 @@
99
host = "127.0.0.1"
1010
port = "8888"
1111
tp_size = 1
12-
model_path = os.environ["QWEN3_32B_FP8"]
12+
model_path = "Qwen/Qwen3-0.6B"
13+
rate_list = [8]
1314
dataset_path = os.environ["SHAREGPT_DATAPATH"]
15+
input_features = [
16+
"random_input_len",
17+
"random_output_len",
18+
"request_rate",
19+
"max_concurrency",
20+
]
21+
22+
output_metrics = [
23+
"mean_ttft_ms",
24+
"p99_ttft_ms",
25+
"mean_tpot_ms",
26+
"p99_tpot_ms",
27+
"mean_itl_ms",
28+
"p99_itl_ms",
29+
"mean_e2e_latency_ms",
30+
"p99_e2e_latency_ms",
31+
"output_throughput",
32+
]
1433

1534

1635
####################################
@@ -23,17 +42,16 @@
2342
--host {host}
2443
--port {port}
2544
--disable-radix-cache
26-
--kv-cache-dtype fp8_e4m3
2745
"""
2846

2947
server_cmds: List[str] = [
3048
server_template.format(
3149
model_path=model_path, tp_size=tp_size, host=host, port=port
3250
),
3351
server_template.format(model_path=model_path, tp_size=tp_size, host=host, port=port)
34-
+ " --tool-call-parser qwen25",
52+
+ " --tool-call-parser qwen",
3553
]
36-
server_labels = ["Qwen3-32B-FP8", "QWEN3-32B-FP8-Without-tool"]
54+
server_labels = ["Qwen3-06B", "QWEN3-06B-With-Tool-Call-Parser"]
3755

3856
##########################
3957
# Constructing client_cmds
@@ -61,27 +79,12 @@
6179
output_len=output_len,
6280
request_rate=rate,
6381
dataset_path=dataset_path,
64-
num_prompt=rate * 10,
82+
num_prompt=min(max(rate * 10, 80), 250),
6583
)
66-
for rate in range(4, 12 + 1, 2)
84+
for rate in rate_list
6785
]
6886

6987
#####################
70-
input_features = [
71-
"request_rate",
72-
]
73-
74-
output_metrics = [
75-
"mean_ttft_ms",
76-
"p99_ttft_ms",
77-
"mean_tpot_ms",
78-
"p99_tpot_ms",
79-
"mean_itl_ms",
80-
"p99_itl_ms",
81-
"mean_e2e_latency_ms",
82-
"p99_e2e_latency_ms",
83-
"output_throughput",
84-
]
8588

8689
if __name__ == "__main__":
8790
cmp_bench(
@@ -92,7 +95,8 @@
9295
server_labels=server_labels,
9396
host=host,
9497
port=port,
95-
output_dir="tool_cmp_bench_output",
96-
n=3,
98+
n=1,
9799
only_last=True,
100+
output_dir="tool_cmp_bench_output",
101+
disable_warmup=True,
98102
)

0 commit comments

Comments
 (0)