Commit 3888cd8

Add Eagle3 decoding_type
Introduce `speculative_config.decoding_type: Eagle3` for the PyTorch backend, warn when `Eagle` is used as an alias, and reject `Eagle3` on the TensorRT backend. Update docs/examples and add unit tests.

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
1 parent 10a4571 commit 3888cd8
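
In user-facing terms, the change looks roughly like this — a sketch distilled from the diffs below (the module path is taken from the test file's import; the model path is a placeholder):

```python
# Sketch of the behavior this commit introduces, distilled from the diffs below.
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig, Eagle3DecodingConfig

spec_cfg = DecodingBaseConfig.from_dict({
    "decoding_type": "Eagle3",  # new, preferred spelling
    "max_draft_len": 3,
    "speculative_model_dir": "/path/to/draft/model",  # placeholder path
})
assert isinstance(spec_cfg, Eagle3DecodingConfig)

# "Eagle" still parses as a PyTorch-backend alias but now logs a warning,
# and passing an Eagle3 config to the TensorRT backend raises ValueError.
```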

File tree

7 files changed: +111 / -45 lines

docs/source/blogs/tech_blog/blog11_GPT_OSS_Eagle3.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -84,7 +84,7 @@ kv_cache_config:
   enable_block_reuse: false
   free_gpu_memory_fraction: 0.8
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 3
   speculative_model_dir: /config/models/eagle/
 cuda_graph_config:
```

docs/source/blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
   -p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \
   -v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \
   docker.io/<username>/tensorrt_llm:main sh \
-  -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
+  -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle3\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
   TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \
   trtllm-serve /config/models/maverick \
   --host 0.0.0.0 --port 8000 \
```

docs/source/features/speculative-decoding.md

Lines changed: 4 additions & 2 deletions
````diff
@@ -125,16 +125,18 @@ llm = LLM("/path/to/target_model", speculative_config=speculative_config)
 Speculative decoding options must be specified via `--extra_llm_api_options config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are:
 
 * `MTP`
-* `Eagle` (for EAGLE 3)
+* `Eagle3` (EAGLE 3)
 * `NGram`
 * `DraftTarget`
 
+> Note: `decoding_type: Eagle` is accepted as a PyTorch-backend alias for `Eagle3`, but `Eagle3` is preferred for clarity.
+
 The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
 
 ```
 disable_overlap_scheduler: true
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 4
   speculative_model: /path/to/draft/model
 ```
````
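
For readers following the LLM API rather than YAML, the programmatic counterpart of the config above would look roughly like this — a sketch: `Eagle3DecodingConfig` comes from this commit's `llm_args.py` diff, the field names follow the `EagleDecodingConfig` usage elsewhere in the commit, and the paths are placeholders:

```python
# Programmatic counterpart of the YAML above (a sketch, not verbatim docs code).
from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import Eagle3DecodingConfig

speculative_config = Eagle3DecodingConfig(
    max_draft_len=4,
    speculative_model_dir="/path/to/draft/model",  # placeholder path
)
llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```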

docs/source/features/torch_compile_and_piecewise_cuda_graph.md

Lines changed: 35 additions & 37 deletions
Large diffs are not rendered by default.

examples/models/core/qwen/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -837,8 +837,8 @@ settings for your specific use case.
 
 Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 on Qwen3, you need to set the following arguments when running `trtllm-bench` or `trtllm-serve`:
 
-- `speculative_config.decoding_type: Eagle`
-  Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
+- `speculative_config.decoding_type: Eagle3`
+  Set the decoding type to `Eagle3` to enable Eagle3 speculative decoding.
 - `speculative_config.max_draft_len: 3`
   Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
 - `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
@@ -855,7 +855,7 @@ Example `extra-llm-api-config.yml` snippet for Eagle3:
 echo "
 enable_attention_dp: false
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 3
   speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
 kv_cache_config:
```

tensorrt_llm/llmapi/llm_args.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -729,6 +729,7 @@ def from_dict(cls, data: dict):
             "MTP": MTPDecodingConfig,
             "Medusa": MedusaDecodingConfig,
             "Eagle": EagleDecodingConfig,
+            "Eagle3": Eagle3DecodingConfig,
             "Lookahead": LookaheadDecodingConfig,
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
@@ -927,6 +928,10 @@ def is_linear_tree(self) -> bool:
         return False
 
 
+class Eagle3DecodingConfig(EagleDecodingConfig):
+    decoding_type: ClassVar[str] = "Eagle3"
+
+
 class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
     output_directory: str
     write_interval: int = 20
@@ -2422,9 +2427,15 @@ def validate_speculative_config(self):
                 decoding_mode=DecodingMode.Medusa(),
                 medusa_choices=self.speculative_config.medusa_choices)
 
+        elif isinstance(self.speculative_config, Eagle3DecodingConfig):
+            raise ValueError(
+                "speculative_config.decoding_type 'Eagle3' is only supported on the PyTorch backend. "
+                "Use decoding_type: Eagle with --backend tensorrt, or switch to --backend pytorch for Eagle3."
+            )
+
         elif isinstance(self.speculative_config, EagleDecodingConfig):
             assert self.speculative_config.max_draft_len > 0
-            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
+            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE weights must be specified."
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
             eagle_config = _EagleConfig(
@@ -2940,6 +2951,10 @@ def validate_speculative_config(self):
                              f"support backend {self.backend}")
 
         if isinstance(self.speculative_config, EagleDecodingConfig):
+            if type(self.speculative_config) is EagleDecodingConfig:
+                logger.warning(
+                    "speculative_config.decoding_type 'Eagle' maps to Eagle3 in the PyTorch backend; "
+                    "use 'Eagle3' to be explicit.")
             assert self.speculative_config.max_draft_len > 0
             assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
         elif isinstance(self.speculative_config, NGramDecodingConfig):
```
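
Two design points in this diff are easy to miss: `from_dict` dispatches on a name-to-class map, and because `Eagle3DecodingConfig` subclasses `EagleDecodingConfig`, existing `isinstance(..., EagleDecodingConfig)` checks keep matching Eagle3 configs. That is why the TensorRT-backend validator tests for `Eagle3DecodingConfig` before the `EagleDecodingConfig` branch, and why the PyTorch-backend warning uses an exact `type(...) is` check. A self-contained miniature of the pattern (stand-in classes, not the real tensorrt_llm code):

```python
# Miniature of the decoding_type -> config-class dispatch used by from_dict.
# These stand-in classes are illustrative only; the real ones live in
# tensorrt_llm/llmapi/llm_args.py and are pydantic models.
from typing import ClassVar


class DecodingBase:
    decoding_type: ClassVar[str] = "Base"

    @classmethod
    def from_dict(cls, data: dict) -> "DecodingBase":
        # The real code uses a literal {"Eagle": ..., "Eagle3": ...} map.
        registry = {sub.decoding_type: sub for sub in (Eagle, Eagle3)}
        config_cls = registry[data.pop("decoding_type")]
        obj = config_cls()
        obj.__dict__.update(data)  # simplified stand-in for pydantic parsing
        return obj


class Eagle(DecodingBase):
    decoding_type: ClassVar[str] = "Eagle"


class Eagle3(Eagle):
    # Subclassing Eagle keeps isinstance(cfg, Eagle) checks matching Eagle3;
    # only an exact type check can single out the legacy "Eagle" alias.
    decoding_type: ClassVar[str] = "Eagle3"


cfg = DecodingBase.from_dict({"decoding_type": "Eagle3", "max_draft_len": 3})
assert isinstance(cfg, Eagle3) and isinstance(cfg, Eagle)
assert type(cfg) is not Eagle  # exact-type check distinguishes the alias
```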

tests/unittest/llmapi/test_llm_args.py

Lines changed: 51 additions & 0 deletions
```diff
@@ -139,6 +139,57 @@ def test_llm_args_with_pydantic_options(self):
         assert llm_args.max_seq_len == 128
 
 
+def test_decoding_type_eagle3_parses_to_eagle3_decoding_config():
+    spec_cfg = DecodingBaseConfig.from_dict({
+        "decoding_type":
+        "Eagle3",
+        "max_draft_len":
+        3,
+        "speculative_model_dir":
+        "/path/to/draft/model",
+    })
+    assert isinstance(spec_cfg, Eagle3DecodingConfig)
+
+
+def test_decoding_type_eagle_warns_on_pytorch_backend(monkeypatch):
+    import tensorrt_llm.llmapi.llm_args as llm_args_mod
+
+    warnings_seen: list[str] = []
+
+    def _capture_warning(msg, *args, **kwargs):
+        warnings_seen.append(str(msg))
+
+    monkeypatch.setattr(llm_args_mod.logger, "warning", _capture_warning)
+
+    spec_cfg = DecodingBaseConfig.from_dict({
+        "decoding_type":
+        "Eagle",
+        "max_draft_len":
+        3,
+        "speculative_model_dir":
+        "/path/to/draft/model",
+    })
+
+    TorchLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
+
+    assert any("maps to Eagle3 in the PyTorch backend" in m
+               for m in warnings_seen)
+
+
+def test_decoding_type_eagle3_errors_on_tensorrt_backend():
+    spec_cfg = DecodingBaseConfig.from_dict({
+        "decoding_type":
+        "Eagle3",
+        "max_draft_len":
+        3,
+        "speculative_model_dir":
+        "/path/to/draft/model",
+    })
+    with pytest.raises(ValueError,
+                       match="only supported on the PyTorch backend"):
+        TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
+
+
 def check_defaults(py_config_cls, pybind_config_cls):
     py_config = py_config_cls()
     pybind_config = pybind_config_cls()
```