Skip to content

Commit 34a3f35

Browse files
committed
[None][fix] Use dynamic tree SpecTreeManager in kv_cache_relocation test + add docs
Switch SpecTreeManager in test_llama_verification_with_kv_cache_relocation from static tree (use_dynamic_tree=False) to dynamic tree mode, removing the eagle_choices parameter. Fixes RuntimeError on H100 (sm<100) where flat single-level eagle_choices produced empty top_k_list tensors. Also add EAGLE3 dynamic tree mode documentation to speculative-decoding.md per reviewer request. Signed-off-by: qgai <qgai@nvidia.com>
1 parent 2ac3662 commit 34a3f35

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

docs/source/features/speculative-decoding.md

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ llm = LLM("/path/to/target_model", speculative_config=speculative_config, disabl
2828
### EAGLE 3
2929

3030
The EAGLE 3 algorithm is described in the paper [EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test](https://arxiv.org/pdf/2503.01840).
31-
TRT-LLM supports a modified version of the algorithm presented in the paper: tree structures for draft sequences are not supported. Instead, each request uses a single sequence of draft tokens with length `max_draft_len`.
31+
By default, each request uses a single sequence (linear chain) of draft tokens with length `max_draft_len`. Optionally, dynamic tree draft generation can be enabled to improve acceptance rates — see [Dynamic Tree Mode](#dynamic-tree-mode) below.
3232

3333
The following draft model checkpoints can be used for EAGLE 3:
3434
* Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
@@ -50,6 +50,36 @@ llm = LLM(model, speculative_config=speculative_config)
5050

5151
EAGLE 3 can be combined with the [Suffix Automaton enhancement](#suffix-automaton-sa-enhancement) for improved acceptance rates on repetitive content. See the SA section below for details.
5252

53+
#### Dynamic Tree Mode
54+
55+
Dynamic tree mode enables tree-structured draft generation for EAGLE 3, where the drafter expands multiple candidate tokens at each layer instead of a single token. This can improve acceptance rates compared to linear drafting at the cost of additional compute per generation step.
56+
57+
To enable dynamic tree mode, set `use_dynamic_tree=True` on the `Eagle3DecodingConfig` and provide the following parameters:
58+
59+
* `use_dynamic_tree` (`bool`): Enables dynamic tree draft generation. Mutually exclusive with `eagle_choices` (static tree).
60+
* `dynamic_tree_max_topK` (`int`): Maximum number of tokens to expand per node at each draft layer.
61+
* `max_total_draft_tokens` (`int`, optional): Total draft token budget for the tree. Must satisfy `max_draft_len <= max_total_draft_tokens <= dynamic_tree_max_topK * max_draft_len`. Defaults to `dynamic_tree_max_topK * max_draft_len` if not set.
62+
* `max_batch_size` (`int`): Required when `use_dynamic_tree=True` for pre-allocating dynamic tree CUDA buffers.
63+
64+
```python
65+
from tensorrt_llm.llmapi import Eagle3DecodingConfig
66+
67+
speculative_config = Eagle3DecodingConfig(
68+
max_draft_len=6,
69+
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
70+
use_dynamic_tree=True,
71+
dynamic_tree_max_topK=10,
72+
max_total_draft_tokens=60,
73+
max_batch_size=4,
74+
)
75+
76+
llm = LLM("/path/to/target_model", speculative_config=speculative_config)
77+
```
78+
79+
```{note}
80+
Dynamic tree mode is currently **not supported** for models that use sliding window attention or MLA (Multi-Latent Attention), such as DeepSeek and gpt-oss models.
81+
```
82+
5383
### NGram
5484

5585
The NGram method is an implementation of [this Prompt Lookup Decoding algorithm](https://github.com/apoorvumang/prompt-lookup-decoding).
@@ -199,6 +229,18 @@ speculative_config:
199229
speculative_model: /path/to/draft/model
200230
```
201231
232+
```yaml
233+
# Dynamic tree mode
234+
speculative_config:
235+
decoding_type: Eagle3
236+
max_draft_len: 6
237+
speculative_model: /path/to/eagle3_model
238+
use_dynamic_tree: true
239+
dynamic_tree_max_topK: 10
240+
max_total_draft_tokens: 60
241+
max_batch_size: 4
242+
```
243+
202244
```yaml
203245
# SA combination: enable Suffix Automaton enhancement with any supported technique
204246
speculative_config:

tests/unittest/_torch/modeling/test_modeling_llama.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -611,13 +611,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
611611
spec_metadata_phase1 = None
612612
if is_tree_phase1:
613613
max_draft_1 = gen_input_ids_1.size(-1) - 1
614-
eagle_choices_phase1 = [[i] for i in range(max_draft_1)]
615614
spec_tree_mgr_phase1 = SpecTreeManager(
616615
max_num_requests=1,
617-
use_dynamic_tree=False,
616+
use_dynamic_tree=True,
618617
max_total_draft_tokens=max_draft_1,
619618
max_draft_len=max_draft_1,
620-
eagle_choices=eagle_choices_phase1,
619+
eagle_choices=None,
621620
dynamic_tree_max_topK=10,
622621
)
623622
spec_metadata_phase1 = SpecMetadata(
@@ -630,7 +629,7 @@ def run_forward(input_ids, position_ids, attn_metadata):
630629
batch_size=batch_size,
631630
is_spec_decoding_enabled=is_spec_decoding_enabled,
632631
is_spec_dec_tree=is_tree_phase1,
633-
is_spec_dec_dynamic_tree=False,
632+
is_spec_dec_dynamic_tree=is_tree_phase1,
634633
max_draft_len=gen_input_ids_1.size(-1) - 1,
635634
max_total_draft_tokens=gen_input_ids_1.size(-1) - 1,
636635
model_is_wrapped=False,
@@ -687,13 +686,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
687686
spec_metadata_ref = None
688687
if is_tree_ref:
689688
max_draft_ref = gen_input_ids_ref.size(-1) - 1
690-
eagle_choices_ref = [[i] for i in range(max_draft_ref)]
691689
spec_tree_mgr_ref = SpecTreeManager(
692690
max_num_requests=1,
693-
use_dynamic_tree=False,
691+
use_dynamic_tree=True,
694692
max_total_draft_tokens=max_draft_ref,
695693
max_draft_len=max_draft_ref,
696-
eagle_choices=eagle_choices_ref,
694+
eagle_choices=None,
697695
dynamic_tree_max_topK=10,
698696
)
699697
spec_metadata_ref = SpecMetadata(
@@ -706,7 +704,7 @@ def run_forward(input_ids, position_ids, attn_metadata):
706704
batch_size=batch_size,
707705
is_spec_decoding_enabled=is_spec_decoding_enabled,
708706
is_spec_dec_tree=is_tree_ref,
709-
is_spec_dec_dynamic_tree=False,
707+
is_spec_dec_dynamic_tree=is_tree_ref,
710708
max_draft_len=gen_input_ids_ref.size(-1) - 1,
711709
max_total_draft_tokens=gen_input_ids_ref.size(-1) - 1,
712710
model_is_wrapped=False,

0 commit comments

Comments
 (0)