
Commit faa80e7

[None][feat] Auto download speculative models from HF for pytorch backend, add speculative_model field alias (#10099)
Signed-off-by: Anish Shanbhag <[email protected]>
1 parent: 62050b2


42 files changed: +430 −277 lines

docs/source/features/speculative-decoding.md

Lines changed: 33 additions & 3 deletions
````diff
@@ -37,8 +37,13 @@ Draft/target is the simplest form of speculative decoding. In this approach, an
 ```python
 from tensorrt_llm.llmapi import DraftTargetDecodingConfig
 
+# Option 1: Use a HuggingFace Hub model ID (auto-downloaded)
 speculative_config = DraftTargetDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model")
+    max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
+
+# Option 2: Use a local path
+# speculative_config = DraftTargetDecodingConfig(
+#     max_draft_len=3, speculative_model="/path/to/draft_model")
 
 llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```
@@ -51,18 +56,23 @@ TRT-LLM supports a modified version of the algorithm presented in the paper: tre
 The following draft model checkpoints can be used for EAGLE 3:
 * Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
 * Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
+* Other models, including `gpt-oss-120b` and `Qwen3`: check out the [Speculative Decoding Modules](https://huggingface.co/collections/nvidia/speculative-decoding-modules) collection from NVIDIA.
 
 ```python
 from tensorrt_llm.llmapi import EagleDecodingConfig
 
 # Enable to use the faster one-model implementation for Llama 4.
 eagle3_one_model = False
+model = "meta-llama/Llama-3.1-8B-Instruct"
+speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 speculative_config = EagleDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
+    max_draft_len=3,
+    speculative_model=speculative_model,
+    eagle3_one_model=eagle3_one_model)
 
 # Only need to disable overlap scheduler if eagle3_one_model is False.
-llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
+llm = LLM(model, speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```
 
 ### NGram
@@ -137,14 +147,34 @@ Speculative decoding options must be specified via `--config config.yaml` for bo
 
 The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
 
+```yaml
+# Using a HuggingFace Hub model ID (auto-downloaded)
+disable_overlap_scheduler: true
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 4
+  speculative_model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
 ```
+
+```yaml
+# Or using a local path
 disable_overlap_scheduler: true
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 4
   speculative_model: /path/to/draft/model
 ```
 
+```{note}
+The field name `speculative_model_dir` can also be used as an alias for `speculative_config.speculative_model`. For example:
+
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 4
+  speculative_model_dir: /path/to/draft/model
+```
+
+
 ## Developer Guide
 
 This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in [`_torch/speculative/interface.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/speculative/interface.py).
````
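The note added above introduces `speculative_model_dir` as an alias for the new `speculative_config.speculative_model` field. A minimal sketch of that equivalence through the LLM API, assuming the alias is also honored as a constructor keyword (the diff states the alias exists but only shows the YAML form; the pre-commit examples did pass `speculative_model_dir=` as a keyword):

```python
# Sketch only: assumes the `speculative_model_dir` alias is accepted at
# construction time, which the note above implies but does not show verbatim.
from tensorrt_llm.llmapi import EagleDecodingConfig

cfg_new = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
cfg_alias = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")

# Both spellings should resolve to the same configured draft model.
assert cfg_new.speculative_model == cfg_alias.speculative_model
```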

examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
     # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
```

examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
     # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
```

examples/llm-api/_tensorrt_engine/llm_medusa_decoding.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the medusa decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, medusa heads num and medusa choices
+    # speculative_model, max_draft_len, medusa heads num and medusa choices
     # with the MedusaDecodingConfig class
     speculative_config = MedusaDecodingConfig(
-        speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
+        speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
         max_draft_len=63,
         num_medusa_heads=4,
         medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
```

examples/llm-api/llm_speculative_decoding.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
 def run_Eagle3():
     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         eagle3_one_model=True)
 
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
```
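Putting the renamed field together with the auto-download behavior, here is a self-contained sketch of an end-to-end EAGLE3 run assembled from the example above and the docs changes; the prompt and sampling parameters are illustrative additions, not part of this commit:

```python
# Minimal end-to-end sketch: the draft model is given as a HuggingFace Hub ID
# and auto-downloaded by the PyTorch backend per this commit.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    eagle3_one_model=True)

llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=spec_config,
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8))

for output in llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32)):
    print(output.outputs[0].text)
```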

examples/llm-api/quickstart_advanced.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -220,11 +220,11 @@ def setup_llm(args, **kwargs):
             relaxed_topk=args.relaxed_topk,
             relaxed_delta=args.relaxed_delta,
             mtp_eagle_one_model=args.use_one_model,
-            speculative_model_dir=args.model_dir)
+            speculative_model=args.model_dir)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir,
+            speculative_model=args.draft_model_dir,
             eagle3_one_model=args.use_one_model,
             eagle_choices=args.eagle_choices,
             use_dynamic_tree=args.use_dynamic_tree,
@@ -234,7 +234,7 @@ def setup_llm(args, **kwargs):
     elif spec_decode_algo == "DRAFT_TARGET":
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir)
+            speculative_model=args.draft_model_dir)
     elif spec_decode_algo == "NGRAM":
         spec_config = NGramDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
```

examples/models/core/qwen/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -841,8 +841,8 @@ Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 o
   Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
 - `speculative_config.max_draft_len: 3`
   Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
-- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
-  Specify the path to the Eagle3 draft model (ensure the corresponding draft model weights are prepared).
+- `speculative_config.speculative_model: <HUGGINGFACE ID / LOCAL PATH>`
+  Specify the Eagle3 draft model either as a Huggingface model ID or a local path. You can find ready-to-use Eagle3 draft models at https://huggingface.co/collections/nvidia/speculative-decoding-modules.
 
 Currently, there are some limitations when enabling Eagle3:
 
@@ -857,7 +857,7 @@ enable_attention_dp: false
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 3
-  speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
+  speculative_model: <HUGGINGFACE ID / LOCAL PATH>
 kv_cache_config:
   enable_block_reuse: false
 " >> ${path_config}
```

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -923,7 +923,7 @@ def create_draft_model_engine_maybe(
     drafting_loop_wrapper = None
 
     draft_model_engine = PyTorchModelEngine(
-        model_path=draft_spec_config.speculative_model_dir,
+        model_path=draft_spec_config.speculative_model,
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,
```

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -887,7 +887,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
                 MistralConfigLoader
             self.draft_config = MistralConfigLoader().load(
-                spec_config.speculative_model_dir,
+                spec_config.speculative_model,
                 mapping=model_config.mapping,
                 moe_backend=model_config.moe_backend,
                 moe_max_num_tokens=model_config.moe_max_num_tokens,
@@ -898,7 +898,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             self.draft_config.extra_attrs = model_config.extra_attrs
         elif spec_config.eagle3_model_arch == "llama3":
             self.draft_config = ModelConfig.from_pretrained(
-                model_config.spec_config.speculative_model_dir,
+                model_config.spec_config.speculative_model,
                 trust_remote_code=True,
                 attn_backend=model_config.attn_backend,
                 moe_backend=model_config.moe_backend,
```

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -278,7 +278,7 @@ def init_meta_tensor(t: torch.Tensor):
         if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
         ):
             weights = checkpoint_loader.load_weights(
-                self.spec_config.speculative_model_dir,
+                self.spec_config.speculative_model,
                 mapping=self.mapping)
 
             draft_model_arch = model.draft_config.pretrained_config.architectures[
```
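The commit title says speculative models are now auto-downloaded from HF for the PyTorch backend, but the download mechanism itself is not visible in this excerpt. A hypothetical sketch of what that resolution step could look like, using the real `huggingface_hub.snapshot_download` API; the helper name is invented for illustration:

```python
# Hypothetical helper (not from this commit): resolve a value like
# "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" to a local checkpoint directory.
import os

from huggingface_hub import snapshot_download


def resolve_speculative_model(speculative_model: str) -> str:
    """Return a local directory for a local path or a HuggingFace Hub ID."""
    if os.path.isdir(speculative_model):
        # Already a local checkpoint directory; use it as-is.
        return speculative_model
    # Otherwise treat the value as a Hub repo ID; snapshot_download fetches
    # the repo (or reuses the local HF cache) and returns its directory.
    return snapshot_download(repo_id=speculative_model)
```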
