
Commit 5adb9ba

Add Spec dec Bench example (#474)
## What does this PR do?

**Type of change:** New feature

**Overview:** Specdec bench example

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

---------

Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
1 parent ae915ea commit 5adb9ba

File tree

26 files changed: +1458 −0 lines changed


.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners
 /examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners
 /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners
 /examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners
+/examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners
 /examples/speculative_decoding @NVIDIA/modelopt-examples-megatron-codeowners
 /examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners
 /examples/vllm_serve @NVIDIA/modelopt-examples-llm_ptq-codeowners

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ Model Optimizer Changelog (Linux)

 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.

+**New Features**
+
+- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
+
 0.39 (2025-11-14)
 ^^^^^^^^^^^^^^^^^
examples/specdec_bench/README.md

Lines changed: 49 additions & 0 deletions
# Speculative Decoding (SpecDec) Bench

## Installation

This benchmark is meant to be a lightweight layer on top of an existing vLLM/SGLang/TRT-LLM installation. For example, no install
is required if you are running in one of the following containers: `vllm/vllm-openai:v0.11.0` (vLLM), `lmsysorg/sglang:v0.5.4.post2` (SGLang), or
`nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1` (TRT-LLM).

Next, change into the example directory:

```bash
cd examples/specdec_bench
```

## Purpose

Collect relevant metrics on acceptance rate, timing, and outputs for speculative decoding methods.
Acceptance rate refers to the number of tokens generated on every decoding iteration; for a standard autoregressive LLM, this number
is just 1, while a speculative method commits one target-model token plus however many draft tokens were accepted.
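As a brief illustration with made-up numbers, the mean acceptance length for a request is just the average number of committed tokens per iteration:

```python
# Hypothetical per-iteration committed-token counts for one request,
# e.g. with --draft_length 3 (at most 1 + 3 = 4 tokens per iteration).
tokens_per_iter = [4, 2, 1, 3]

acceptance_rate = sum(tokens_per_iter) / len(tokens_per_iter)
print(acceptance_rate)  # 2.5 tokens/iteration; a plain autoregressive run gives 1.0
```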
## Getting Started

A basic example run script is provided which benchmarks MTBench (a standard set of 160 prompts spanning 8 categories).
MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts).

### Running MTBench on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b \
    --draft_model_dir /path/to/eagle --mtbench question.jsonl \
    --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 \
    --num_requests 80 --engine TRTLLM --concurrency 1
```

### Running random IDs on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b \
    --draft_model_dir /path/to/eagle --random_isl 1024 \
    --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 \
    --num_requests 40 --engine TRTLLM --concurrency 1
```
## Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative decoding implementations across frameworks in an apples-to-apples manner.
This benchmark sends requests in a single-threaded fashion, so running at large concurrency (>256) may introduce Python async scheduling delays and skew metrics.
If larger concurrency is needed, it is recommended to fully deploy the model using `vllm serve`, `python -m sglang.launch_server`, or `trtllm-serve` (for vLLM, SGLang, or TRT-LLM respectively) and
use a more robust benchmarking client like NVIDIA AI Perf.

examples/specdec_bench/run.py

Lines changed: 196 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import asyncio

import yaml
from specdec_bench import datasets, metrics, models, runners
from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base

engines_available = {
    "TRTLLM": models.TRTLLMPYTModel,
    "VLLM": models.VLLMModel,
    "SGLANG": models.SGLANGModel,
}


async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
    """
    Async version of run_loop with concurrency control using a semaphore.

    Args:
        runner: The model runner instance
        dataset: The dataset containing requests
        tokenizer: The tokenizer instance
        output_length: Maximum output length
        postprocess: Callable applied to each decoded assistant turn
        concurrency: Maximum number of concurrent requests (default: 10)
    """
    semaphore = asyncio.Semaphore(concurrency)
    max_length = output_length
    end_id = tokenizer.eos_token_id

    async def process_single_request(request, i):
        """Process a single request with all its conversation turns."""
        async with semaphore:
            messages = []
            if request.system_prompt is not None:
                messages.append({"role": "system", "content": request.system_prompt})

            for question in request.turns:
                messages.append({"role": "user", "content": question})
                entry_encoded = encode_chat(tokenizer, messages)

                # Run the async runner.run directly
                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
                output_text = postprocess(output_text)
                messages.append({"role": "assistant", "content": output_text})

            return messages

    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)

    # Check for any exceptions and handle them
    for i, result in enumerate(text_outputs):
        if isinstance(result, Exception):
            print(f"Error processing request {i}: {result}")
            raise result

    runner.process_metrics_final(text_outputs)
    return text_outputs


def run_simple(args):
    tokenizer = get_tokenizer(args.tokenizer)
    dataset_kwargs = args.runtime_params.get("dataset_kwargs", {})
    if args.mtbench is not None:
        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
    elif args.random_isl is not None:
        dataset = datasets.RandomToken(
            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
        )
    engine_args = args.runtime_params.get("engine_args", {})
    sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0})
    model_class = engines_available[args.engine]
    model = model_class(
        args.model_dir,
        max_concurrent_requests=args.concurrency,
        sampling_kwargs=sampling_kwargs,
        speculative_algorithm=args.speculative_algorithm,
        draft_model_dir=args.draft_model_dir,
        speculative_num_steps=args.draft_length,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.ep_size,
        **engine_args,
    )

    metrics_list = [metrics.Timing(args.tp_size)]
    if args.aa_timing:
        metrics_list.append(metrics.AATiming(tokenizer))
    if args.mtbench is not None:
        metrics_list.insert(0, metrics.MTBench())
    else:
        metrics_list.insert(0, metrics.AcceptanceRate())
    runner = runners.SimpleRunner(model, metrics=metrics_list)

    postprocess = postprocess_base

    asyncio.run(
        run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
    )

    runner.clear_metrics()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer", type=str, required=True, help="Path to the tokenizer directory"
    )
    parser.add_argument(
        "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset"
    )
    parser.add_argument(
        "--random_isl",
        type=int,
        required=False,
        default=None,
        help="Number of random input tokens to generate per request.",
    )
    parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run")
    parser.add_argument(
        "--engine",
        type=str,
        required=False,
        default="TRTLLM",
        choices=list(engines_available.keys()),
        help="Engine to use",
    )
    parser.add_argument(
        "--speculative_algorithm",
        type=str,
        required=False,
        default="EAGLE3",
        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
        help="Speculative algorithm to use",
    )
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
    parser.add_argument(
        "--draft_model_dir",
        type=str,
        required=False,
        default=None,
        help="Path to the draft model directory",
    )
    parser.add_argument(
        "--runtime_params",
        type=str,
        required=False,
        default=None,
        help="Path to the runtime params yaml file",
    )
    parser.add_argument(
        "--output_length", type=int, required=False, default=4096, help="Output length"
    )
    parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
    parser.add_argument(
        "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
    )
    parser.add_argument(
        "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent requests",
    )
    parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric")
    args = parser.parse_args()

    if args.runtime_params is not None:
        with open(args.runtime_params) as f:
            args.runtime_params = yaml.safe_load(f)
    else:
        args.runtime_params = {}

    assert args.mtbench is not None or args.random_isl is not None, (
        "Either mtbench or random_isl must be provided"
    )

    run_simple(args)
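The `--runtime_params` file is loaded with `yaml.safe_load`, and `run_simple` reads the top-level keys `dataset_kwargs`, `engine_args`, and `sampling_kwargs`. A minimal sketch of producing such a file follows; the nested values are illustrative assumptions, not documented options:

```python
import yaml

# Sketch of a runtime params file. The three top-level keys are the ones
# run.py reads; everything nested under them is passed through as kwargs.
runtime_params = {
    "dataset_kwargs": {},                   # forwarded to the dataset constructor
    "engine_args": {},                      # forwarded to the engine model class
    "sampling_kwargs": {"temperature": 0},  # mirrors run.py's default
}

with open("runtime_params.yaml", "w") as f:
    yaml.safe_dump(runtime_params, f)

# Then: python3 run.py ... --runtime_params runtime_params.yaml
```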
Lines changed: 14 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 19 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Dataset
from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat
from .mtbench import MTBench
from .random_token import RandomToken
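These exports are what `run.py` consumes. A brief usage sketch, mirroring how `run.py` constructs the MTBench dataset (the path and request count here echo the README example and are assumptions):

```python
from specdec_bench import datasets

# run.py calls datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
dataset = datasets.MTBench("question.jsonl", 80)

for request in dataset.data:
    # Each entry is a Request with one or more user turns.
    print(request.turns[0][:80])
```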
Lines changed: 37 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Request:
    system_prompt: str | None = None
    turns: list[str] = field(default_factory=list)
    mm_content: Any | None = None  # TODO

    # not to be set by user; annotated so it becomes a real dataclass field
    output_turn_ids: Any | None = None
    output_turn_text: list[str] = field(default_factory=list)


class Dataset:
    def __init__(self, path, **kwargs):
        self.data: list[Request] = []
        raise NotImplementedError

    def _preprocess(self):
        raise NotImplementedError
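For context, a minimal sketch of a concrete dataset built on this interface; the class name, `num_requests` parameter, and JSONL schema are hypothetical, not part of the commit:

```python
import json


class JsonlPrompts(Dataset):
    """Hypothetical example: one single-turn Request per JSONL line {"prompt": "..."}."""

    def __init__(self, path, num_requests=None, **kwargs):
        # The base __init__ raises NotImplementedError, so we do not call super().
        self.data: list[Request] = []
        self.path = path
        self.num_requests = num_requests
        self._preprocess()

    def _preprocess(self):
        with open(self.path) as f:
            for i, line in enumerate(f):
                if self.num_requests is not None and i >= self.num_requests:
                    break
                entry = json.loads(line)
                self.data.append(Request(turns=[entry["prompt"]]))
```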
