Skip to content

Commit db5b876

Browse files
authored
[None][feat] support for more accurate AR calculation (NVIDIA#9323)
Signed-off-by: binghanc <[email protected]>
1 parent f8dd494 commit db5b876

File tree

1 file changed

+88
-29
lines changed

1 file changed

+88
-29
lines changed

tensorrt_llm/serve/scripts/benchmark_serving.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,13 @@ class BenchmarkMetrics:
8080
std_e2el_ms: float
8181
percentiles_e2el_ms: list[tuple[float, float]]
8282
tput_user: list[float]
83-
avg_decoded_tokens_per_iter: float
83+
# Statistics for avg_decoded_tokens_per_iter across all requests
84+
mean_avg_decoded_tokens_per_iter: float
85+
min_avg_decoded_tokens_per_iter: float
86+
max_avg_decoded_tokens_per_iter: float
87+
median_avg_decoded_tokens_per_iter: float
88+
std_avg_decoded_tokens_per_iter: float
89+
percentiles_avg_decoded_tokens_per_iter: list[tuple[float, float]]
8490

8591

8692
async def get_request(
@@ -144,7 +150,7 @@ def calculate_metrics(
144150
ttfts: list[float] = []
145151
e2els: list[float] = []
146152
tput_user: list[float] = []
147-
latest_avg_decoded_tokens_per_iter: float = 0.0
153+
avg_decoded_tokens_per_iter_list: list[float] = []
148154
error_counts: dict[str, int] = {}
149155
for i in range(len(outputs)):
150156
if outputs[i].exception_type:
@@ -177,11 +183,11 @@ def calculate_metrics(
177183
tput_user.append(output_len / (outputs[i].latency))
178184
completed += 1
179185

180-
# Track the latest avg_decoded_tokens_per_iter if available
186+
# Collect avg_decoded_tokens_per_iter for all requests
181187
if hasattr(outputs[i], 'avg_decoded_tokens_per_iter'
182188
) and outputs[i].avg_decoded_tokens_per_iter is not None:
183-
latest_avg_decoded_tokens_per_iter = outputs[
184-
i].avg_decoded_tokens_per_iter
189+
avg_decoded_tokens_per_iter_list.append(
190+
outputs[i].avg_decoded_tokens_per_iter)
185191
else:
186192
actual_output_lens.append(0)
187193

@@ -247,7 +253,20 @@ def calculate_metrics(
247253
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
248254
for p in selected_percentiles],
249255
tput_user=np.mean(tput_user or 0),
250-
avg_decoded_tokens_per_iter=latest_avg_decoded_tokens_per_iter,
256+
mean_avg_decoded_tokens_per_iter=np.mean(
257+
avg_decoded_tokens_per_iter_list or 0),
258+
min_avg_decoded_tokens_per_iter=np.min(avg_decoded_tokens_per_iter_list)
259+
if avg_decoded_tokens_per_iter_list else 0.0,
260+
max_avg_decoded_tokens_per_iter=np.max(avg_decoded_tokens_per_iter_list)
261+
if avg_decoded_tokens_per_iter_list else 0.0,
262+
median_avg_decoded_tokens_per_iter=np.median(
263+
avg_decoded_tokens_per_iter_list or 0),
264+
std_avg_decoded_tokens_per_iter=np.std(avg_decoded_tokens_per_iter_list
265+
or 0),
266+
percentiles_avg_decoded_tokens_per_iter=[
267+
(p, np.percentile(avg_decoded_tokens_per_iter_list or 0, p))
268+
for p in selected_percentiles
269+
],
251270
)
252271
return metrics, actual_output_lens
253272

@@ -466,10 +485,6 @@ async def limited_request_func(request_func_input, streaming, pbar,
466485
print("{:<40} {:<10.2f}".format("User throughput (tok/s):",
467486
metrics.tput_user))
468487

469-
# Print last avg_decoded_tokens_per_iter value if available
470-
if metrics.avg_decoded_tokens_per_iter > 0.0:
471-
print("{:<40} {:<10.2f}".format("Avg Decoded Tokens per Iter:",
472-
metrics.avg_decoded_tokens_per_iter))
473488
if len(outputs) - metrics.completed > 0:
474489
print(
475490
f"=======================!FAILED REQUESTS!=======================")
@@ -488,7 +503,17 @@ async def limited_request_func(request_func_input, streaming, pbar,
488503
"output_throughput": metrics.output_throughput,
489504
"total_token_throughput": metrics.total_token_throughput,
490505
"user_throughput": metrics.tput_user,
491-
"avg_decoded_tokens_per_iter": metrics.avg_decoded_tokens_per_iter,
506+
"avg_decoded_tokens_per_iter": {
507+
"mean": metrics.mean_avg_decoded_tokens_per_iter,
508+
"min": metrics.min_avg_decoded_tokens_per_iter,
509+
"max": metrics.max_avg_decoded_tokens_per_iter,
510+
"median": metrics.median_avg_decoded_tokens_per_iter,
511+
"std": metrics.std_avg_decoded_tokens_per_iter,
512+
"percentiles": {
513+
f"p{p}": v
514+
for p, v in metrics.percentiles_avg_decoded_tokens_per_iter
515+
}
516+
},
492517
"input_lens": [output.prompt_len for output in outputs],
493518
"output_lens": actual_output_lens,
494519
"ttfts": [output.ttft for output in outputs],
@@ -504,30 +529,64 @@ def process_one_metric(
504529
metric_name: str,
505530
# E.g., "Time to First Token"
506531
metric_header: str,
532+
# E.g., "ms" or "" for no unit
533+
unit_suffix: str = "ms",
507534
):
508-
# This function prints and adds statistics of the specified
509-
# metric.
510-
if metric_attribute_name not in selected_percentile_metrics:
535+
# This function prints and adds statistics of the specified metric.
536+
# Skip if not in selected metrics (except avg_decoded_tokens_per_iter which has its own condition)
537+
if (metric_attribute_name not in selected_percentile_metrics
538+
and metric_attribute_name != "avg_decoded_tokens_per_iter"):
511539
return
540+
541+
# Build attribute suffix (e.g., "_ms" or "")
542+
attr_suffix = f"_{unit_suffix}" if unit_suffix else ""
543+
# Build display unit (e.g., " (ms)" or "")
544+
display_unit = f" ({unit_suffix})" if unit_suffix else ""
545+
512546
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
513547
print("{:<40} {:<10.2f}".format(
514-
f"Mean {metric_name} (ms):",
515-
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
548+
f"Mean {metric_name}{display_unit}:",
549+
getattr(metrics, f"mean_{metric_attribute_name}{attr_suffix}")))
516550
print("{:<40} {:<10.2f}".format(
517-
f"Median {metric_name} (ms):",
518-
getattr(metrics, f"median_{metric_attribute_name}_ms")))
519-
result[f"mean_{metric_attribute_name}_ms"] = getattr(
520-
metrics, f"mean_{metric_attribute_name}_ms")
521-
result[f"median_{metric_attribute_name}_ms"] = getattr(
522-
metrics, f"median_{metric_attribute_name}_ms")
523-
result[f"std_{metric_attribute_name}_ms"] = getattr(
524-
metrics, f"std_{metric_attribute_name}_ms")
525-
for p, value in getattr(metrics,
526-
f"percentiles_{metric_attribute_name}_ms"):
551+
f"Median {metric_name}{display_unit}:",
552+
getattr(metrics, f"median_{metric_attribute_name}{attr_suffix}")))
553+
if hasattr(metrics, f"std_{metric_attribute_name}{attr_suffix}"):
554+
print("{:<40} {:<10.2f}".format(
555+
f"Std Dev {metric_name}{display_unit}:",
556+
getattr(metrics, f"std_{metric_attribute_name}{attr_suffix}")))
557+
result[f"std_{metric_attribute_name}{attr_suffix}"] = getattr(
558+
metrics, f"std_{metric_attribute_name}{attr_suffix}")
559+
if hasattr(metrics, f"min_{metric_attribute_name}{attr_suffix}"):
560+
print("{:<40} {:<10.2f}".format(
561+
f"Min {metric_name}{display_unit}:",
562+
getattr(metrics, f"min_{metric_attribute_name}{attr_suffix}")))
563+
result[f"min_{metric_attribute_name}{attr_suffix}"] = getattr(
564+
metrics, f"min_{metric_attribute_name}{attr_suffix}")
565+
if hasattr(metrics, f"max_{metric_attribute_name}{attr_suffix}"):
566+
print("{:<40} {:<10.2f}".format(
567+
f"Max {metric_name}{display_unit}:",
568+
getattr(metrics, f"max_{metric_attribute_name}{attr_suffix}")))
569+
result[f"max_{metric_attribute_name}{attr_suffix}"] = getattr(
570+
metrics, f"max_{metric_attribute_name}{attr_suffix}")
571+
572+
result[f"mean_{metric_attribute_name}{attr_suffix}"] = getattr(
573+
metrics, f"mean_{metric_attribute_name}{attr_suffix}")
574+
result[f"median_{metric_attribute_name}{attr_suffix}"] = getattr(
575+
metrics, f"median_{metric_attribute_name}{attr_suffix}")
576+
577+
for p, value in getattr(
578+
metrics, f"percentiles_{metric_attribute_name}{attr_suffix}"):
527579
p_word = str(int(p)) if int(p) == p else str(p)
528-
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
529-
value))
530-
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
580+
print("{:<40} {:<10.2f}".format(
581+
f"P{p_word} {metric_name}{display_unit}:", value))
582+
result[f"p{p_word}_{metric_attribute_name}{attr_suffix}"] = value
583+
584+
# Print avg_decoded_tokens_per_iter statistics if available
585+
if metrics.mean_avg_decoded_tokens_per_iter > 0.0:
586+
process_one_metric("avg_decoded_tokens_per_iter",
587+
"Avg Decoded Tokens per Iter",
588+
"Avg Decoded Tokens per Iter",
589+
unit_suffix="")
531590

532591
process_one_metric("ttft", "TTFT", "Time to First Token")
533592
process_one_metric("tpot", "TPOT",

0 commit comments

Comments (0)