36 changes: 33 additions & 3 deletions docs/source/features/speculative-decoding.md
@@ -37,8 +37,13 @@ Draft/target is the simplest form of speculative decoding. In this approach, an
 ```python
 from tensorrt_llm.llmapi import DraftTargetDecodingConfig
 
+# Option 1: Use a HuggingFace Hub model ID (auto-downloaded)
 speculative_config = DraftTargetDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model")
+    max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
+
+# Option 2: Use a local path
+# speculative_config = DraftTargetDecodingConfig(
+#     max_draft_len=3, speculative_model="/path/to/draft_model")
 
 llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```
@@ -51,18 +56,23 @@ TRT-LLM supports a modified version of the algorithm presented in the paper: tre
 The following draft model checkpoints can be used for EAGLE 3:
 * Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
 * Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
+* Other models, including `gpt-oss-120b` and `Qwen3`: check out the [Speculative Decoding Modules](https://huggingface.co/collections/nvidia/speculative-decoding-modules) collection from NVIDIA.
 
 ```python
 from tensorrt_llm.llmapi import EagleDecodingConfig
 
 # Enable to use the faster one-model implementation for Llama 4.
 eagle3_one_model = False
+model = "meta-llama/Llama-3.1-8B-Instruct"
+speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 speculative_config = EagleDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
+    max_draft_len=3,
+    speculative_model=speculative_model,
+    eagle3_one_model=eagle3_one_model)
 
 # Only need to disable overlap scheduler if eagle3_one_model is False.
-llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
+llm = LLM(model, speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```
 
 ### NGram
@@ -137,14 +147,34 @@ Speculative decoding options must be specified via `--config config.yaml` for bo

 The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
 
 ```yaml
+# Using a HuggingFace Hub model ID (auto-downloaded)
 disable_overlap_scheduler: true
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 4
+  speculative_model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
+```
+
+```yaml
+# Or using a local path
+disable_overlap_scheduler: true
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 4
+  speculative_model: /path/to/draft/model
+```
+
+```{note}
+The field name `speculative_model_dir` can also be used as an alias for `speculative_config.speculative_model`. For example:
+
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 4
   speculative_model_dir: /path/to/draft/model
 ```
+
 
 ## Developer Guide
 
 This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in [`_torch/speculative/interface.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/speculative/interface.py).
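Taken together, the documentation changes above reduce to the following usage pattern. A minimal sketch, assuming `LLM` is importable from `tensorrt_llm` as in the example scripts later in this diff:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# `speculative_model` now accepts either a HuggingFace Hub model ID
# (auto-downloaded) or a local checkpoint directory.
speculative_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    eagle3_one_model=True)

# Per the docs above, the overlap scheduler only needs to be disabled
# when eagle3_one_model is False, so it stays enabled here.
llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=speculative_config)
```
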
4 changes: 2 additions & 2 deletions examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py
@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices,
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices,
     # greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
4 changes: 2 additions & 2 deletions examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py
@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices,
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices,
     # greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
4 changes: 2 additions & 2 deletions examples/llm-api/_tensorrt_engine/llm_medusa_decoding.py
@@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the medusa decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, medusa heads num and medusa choices
+    # speculative_model, max_draft_len, medusa heads num and medusa choices
     # with the MedusaDecodingConfig class
     speculative_config = MedusaDecodingConfig(
-        speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
+        speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
         max_draft_len=63,
         num_medusa_heads=4,
         medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
2 changes: 1 addition & 1 deletion examples/llm-api/llm_speculative_decoding.py
@@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
 def run_Eagle3():
     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         eagle3_one_model=True)
 
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
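For context, `run_Eagle3` presumably wires these two configs into an `LLM` instance roughly as in this hedged sketch (the target model ID and prompt are assumptions, not shown in the hunk):

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    eagle3_one_model=True)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

# Assumed target model; the script's actual choice lies outside this hunk.
llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=spec_config,
          kv_cache_config=kv_cache_config)
for output in llm.generate(["The future of AI is"],
                           SamplingParams(max_tokens=32)):
    print(output.outputs[0].text)
```
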
6 changes: 3 additions & 3 deletions examples/llm-api/quickstart_advanced.py
@@ -220,11 +220,11 @@ def setup_llm(args, **kwargs):
             relaxed_topk=args.relaxed_topk,
             relaxed_delta=args.relaxed_delta,
             mtp_eagle_one_model=args.use_one_model,
-            speculative_model_dir=args.model_dir)
+            speculative_model=args.model_dir)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir,
+            speculative_model=args.draft_model_dir,
             eagle3_one_model=args.use_one_model,
             eagle_choices=args.eagle_choices,
             use_dynamic_tree=args.use_dynamic_tree,
@@ -234,7 +234,7 @@ def setup_llm(args, **kwargs):
     elif spec_decode_algo == "DRAFT_TARGET":
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir)
+            speculative_model=args.draft_model_dir)
     elif spec_decode_algo == "NGRAM":
         spec_config = NGramDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
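The `setup_llm` changes above rename only the keyword argument; the dispatch over `spec_decode_algo` is otherwise untouched. A condensed sketch of the resulting pattern (the helper name and parameters are hypothetical, not part of the script):

```python
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig

# Hypothetical helper mirroring setup_llm's dispatch; algo, draft_len and
# draft_model stand in for the script's argparse values.
def make_spec_config(algo: str, draft_len: int, draft_model: str):
    if algo == "EAGLE3":
        return EagleDecodingConfig(
            max_draft_len=draft_len,
            speculative_model=draft_model)  # renamed from speculative_model_dir
    if algo == "DRAFT_TARGET":
        return DraftTargetDecodingConfig(
            max_draft_len=draft_len,
            speculative_model=draft_model)
    raise ValueError(f"Unsupported speculative decoding algorithm: {algo}")
```
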
6 changes: 3 additions & 3 deletions examples/models/core/qwen/README.md
@@ -841,8 +841,8 @@ Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 o
 Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
 - `speculative_config.max_draft_len: 3`
 Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
-- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
-Specify the path to the Eagle3 draft model (ensure the corresponding draft model weights are prepared).
+- `speculative_config.speculative_model: <HUGGINGFACE ID / LOCAL PATH>`
+Specify the Eagle3 draft model either as a HuggingFace model ID or a local path. You can find ready-to-use Eagle3 draft models at https://huggingface.co/collections/nvidia/speculative-decoding-modules.
 
 Currently, there are some limitations when enabling Eagle3:
 
@@ -857,7 +857,7 @@ enable_attention_dp: false
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 3
-  speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
+  speculative_model: <HUGGINGFACE ID / LOCAL PATH>
 kv_cache_config:
   enable_block_reuse: false
 " >> ${path_config}
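For comparison, a hedged LLM-API equivalent of the YAML that the snippet above appends to `${path_config}` (the target model path is a placeholder, and the keyword names are assumed to mirror the YAML fields):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

# Mirrors the YAML above: Eagle3 decoding with KV-cache block reuse disabled.
llm = LLM(
    "<QWEN3_MODEL_PATH>",  # placeholder target model
    enable_attention_dp=False,
    speculative_config=EagleDecodingConfig(
        max_draft_len=3,
        speculative_model="<HUGGINGFACE ID / LOCAL PATH>"),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)
```
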
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -921,7 +921,7 @@ def create_draft_model_engine_maybe(
     drafting_loop_wrapper = None
 
     draft_model_engine = PyTorchModelEngine(
-        model_path=draft_spec_config.speculative_model_dir,
+        model_path=draft_spec_config.speculative_model,
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -887,7 +887,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
                 MistralConfigLoader
             self.draft_config = MistralConfigLoader().load(
-                spec_config.speculative_model_dir,
+                spec_config.speculative_model,
                 mapping=model_config.mapping,
                 moe_backend=model_config.moe_backend,
                 moe_max_num_tokens=model_config.moe_max_num_tokens,
@@ -898,7 +898,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             self.draft_config.extra_attrs = model_config.extra_attrs
         elif spec_config.eagle3_model_arch == "llama3":
             self.draft_config = ModelConfig.from_pretrained(
-                model_config.spec_config.speculative_model_dir,
+                model_config.spec_config.speculative_model,
                 trust_remote_code=True,
                 attn_backend=model_config.attn_backend,
                 moe_backend=model_config.moe_backend,
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -278,7 +278,7 @@ def init_meta_tensor(t: torch.Tensor):
         if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
         ):
             weights = checkpoint_loader.load_weights(
-                self.spec_config.speculative_model_dir,
+                self.spec_config.speculative_model,
                 mapping=self.mapping)
 
             draft_model_arch = model.draft_config.pretrained_config.architectures[
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -398,7 +398,7 @@ def drafting_loop_wrapper(model):
     draft_llm_args.load_format = LoadFormat.DUMMY
 
     draft_model_engine = PyTorchModelEngine(
-        model_path=spec_config.speculative_model_dir,
+        model_path=spec_config.speculative_model,
         llm_args=draft_llm_args,
         mapping=mapping,
         attn_runtime_features=attn_runtime_features,