[TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support #12062

Open
sunnyqgg wants to merge 8 commits into NVIDIA:main from sunnyqgg:add_dyanmic_tree_support_one_model

Conversation

@sunnyqgg
Collaborator

@sunnyqgg sunnyqgg commented Mar 10, 2026

Summary

  • Add dynamic tree speculative decoding support for EAGLE3 (both one-model and two-model flows)
  • Implement Eagle3OneModelDynamicTreeWorker and Eagle3OneModelDynamicTreeSampler for one-model dynamic tree inference
  • Add CUDA kernels for dynamic tree operations (expand, gather, update)
  • Support growing context in dynamic tree mode for improved accept rates

Changes

  • New: tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py — One-model dynamic tree worker and sampler
  • New: tensorrt_llm/_torch/speculative/dynamic_tree_ops.py — Python wrappers for dynamic tree CUDA ops
  • New: cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu/.h — CUDA kernels
  • New: cpp/tensorrt_llm/thop/dynamicTreeOp.cpp — Torch custom op bindings
  • Modified: tensorrt_llm/_torch/speculative/eagle3.py — Refactored Eagle3OneModelWorker with dispatch pattern for linear vs dynamic tree
  • Modified: tensorrt_llm/_torch/speculative/utils.py — Route to dynamic tree components when use_dynamic_tree=True
  • Modified: tensorrt_llm/_torch/speculative/drafting_loops.py — Two-model dynamic tree drafting loop
  • Modified: tensorrt_llm/_torch/speculative/model_drafter.py — Dynamic tree spec tree manager integration
  • Modified: tensorrt_llm/_torch/speculative/spec_tree_manager.py — Support dynamic tree token organization
  • Modified: tensorrt_llm/_torch/pyexecutor/model_engine.py — Dynamic tree detection for target model
  • Modified: tensorrt_llm/_torch/pyexecutor/sampler.py — Dynamic tree batch verification
  • Modified: tensorrt_llm/_torch/models/modeling_speculative.py — Hidden states handling for dynamic tree
  • Modified: tensorrt_llm/llmapi/llm_args.py — Configuration validation for dynamic tree
  • Modified: Attention backend files for dynamic tree metadata support

Test plan

  • Unit tests pass (tests/unittest/_torch/speculative/)
  • One-model dynamic tree EAGLE3 inference matches two-model accept rates
  • Two-model dynamic tree EAGLE3 inference works correctly
  • No regression in standard (non-dynamic-tree) EAGLE3 flows
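
For context, a minimal sketch of how this configuration surface could be exercised through the LLM API. The field names (use_dynamic_tree, dynamic_tree_max_topK, max_draft_len, max_total_draft_tokens) come from this PR's validation logic; the import paths and the speculative_model_dir argument are assumptions based on the existing EAGLE3 examples and may differ.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    speculative_model_dir="<path-to-eagle3-draft-model>",  # assumed argument
    max_draft_len=4,
    use_dynamic_tree=True,
    dynamic_tree_max_topK=3,
    # Optional: defaults to dynamic_tree_max_topK * max_draft_len (12 here) and
    # must not exceed topK + topK^2 * (max_draft_len - 1) = 30.
    max_total_draft_tokens=12,
)

llm = LLM(model="<path-to-target-model>", speculative_config=spec_config)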

Summary by CodeRabbit

  • New Features

    • Added dynamic tree speculative decoding support for Eagle3, enabling efficient parallel token prediction with flexible tree structures
    • Added streaming generation support with configurable max draft tokens and speculative decoding parameters
    • Enhanced token verification with CUDA-accelerated kernels for improved inference performance
  • Configuration

    • Added --max_total_draft_tokens parameter for controlling total draft token budget
    • Added --streaming flag for real-time token streaming output

@sunnyqgg
Collaborator Author

/bot run

@sunnyqgg sunnyqgg marked this pull request as draft March 10, 2026 04:33
@tensorrt-cicd
Collaborator

PR_Github #38374 [ run ] triggered by Bot. Commit: b15f7d6 Link to invocation

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough

This pull request introduces dynamic tree-based speculative decoding for EAGLE3 inference with CUDA-accelerated kernels. It adds tree construction and greedy verification kernels, integrates them with Python/PyTorch layers, implements dynamic tree sampling and acceptance logic, and extends the Eagle3 resource manager and worker infrastructure to support dynamic tree mode alongside static tree mode.

Changes

Cohort / File(s) Summary
CUDA Kernels & Interface
cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu, cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
New kernels for dynamic tree construction (packed and non-packed treeMask variants) and greedy tree verification. Includes kernel launchers (invokeBuildDynamicTree, invokeVerifyDynamicTreeGreedy) and TreeMaskMode enum (QLEN_ONLY, QLEN_ONLY_BITPACKING).
Torch C++ Extension
cpp/tensorrt_llm/thop/dynamicTreeOp.cpp, cpp/tensorrt_llm/thop/CMakeLists.txt
PyTorch custom operators binding CUDA kernels: build_dynamic_tree_op and verify_dynamic_tree_greedy_op. Includes input validation, buffer initialization, and kernel invocation via tensor interfaces.
Dynamic Tree Operations
tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
Core Python abstraction layer introducing DynamicTreeBuffers, VerifyTreeResults, DynamicTreeOpsConverter, and factory function create_dynamic_tree_ops_converter for managing tree construction and verification with preallocated buffers and error handling.
Drafting Loop Infrastructure
tensorrt_llm/_torch/speculative/drafting_loops.py
Renames existing TreeDraftingLoopWrapper to StaticTreeDraftingLoopWrapper and introduces new DynamicTreeDraftingLoopWrapper for dynamic-tree drafting with per-layer sampling, extensive buffering, and CUDA kernel integration.
Dynamic Tree Sampling & Worker
tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
New module with Eagle3OneModelDynamicTreeSampler and Eagle3OneModelDynamicTreeWorker implementing dynamic-tree verification, KV-cache management, draft token generation, and tree mask/position preparation with extensive internal helpers (sample_dynamic, dt_update_draft_tokens_and_scores, dt_resampling_final_draft_tokens, etc.).
Sampler Integration
tensorrt_llm/_torch/pyexecutor/sampler.py
Adds batch-driven dynamic tree verification (_batch_verify_dynamic_tree) and per-request processing (_process_draft_tokens_dynamic_tree) integrated into main sampling flow. Includes dynamic tree result handling and fallback to greedy path when unavailable.
Eagle3 Configuration & Management
tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/spec_tree_manager.py
Introduces Eagle3OneModelDynamicTreeResourceManager and expands Eagle3OneModelSpecMetadata with use_dynamic_tree and eagle_choices fields. Adds dynamic-tree buffer allocation, eagle_paths selection logic, spec_dec_packed_mask computation, and drafter-model offsets in SpecTreeManager.
Attention Backend Interface Updates
tensorrt_llm/_torch/attention_backend/interface.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/attention_backend/trtllm.py
Updated update_spec_dec_param signatures: removed spec_metadata and spec_decoding_tensor parameters, added is_target_model flag. Backend implementations now reshape position_offsets for dynamic trees and handle spec decoding buffers with spec_tree_manager integration, including support for drafter-layer paths.
Model & Executor Updates
tensorrt_llm/_torch/models/modeling_speculative.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Minor refactoring in Eagle3ForCausalLM; removed spec_decoding_tensor from PyTorchModelEngine; added warmup completion flag in PyExecutor; and expanded py_executor_creator to conditionally select DynamicTreeDraftingLoopWrapper or StaticTreeDraftingLoopWrapper based on EagleDecodingConfig flags.
Model Drafter
tensorrt_llm/_torch/speculative/model_drafter.py
Added dynamic-tree buffer handling in static-draft-output path. Extracts and propagates retrieve_index, retrieve_next_token, retrieve_next_sibling from tree structures to spec_tree_manager per-request state. Updated prepare_draft_tokens signature to require ResourceManager explicitly.
Configuration & Utilities
tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/speculative/utils.py
Enhanced EagleDecodingConfig.validate_eagle_config with unified dynamic-tree validation: enforces dynamic_tree_max_topK, requires eagle_choices=None for dynamic mode, defaults/validates max_total_draft_tokens. Expanded utils.py to route to dynamic-tree variants (ResourceManager, Sampler, Worker) when use_dynamic_tree is enabled.
Example & Tests
examples/llm-api/quickstart_advanced.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
Added --max_total_draft_tokens and --streaming CLI arguments to quickstart_advanced. Test imports updated to use StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper; added new test functions for dynamic tree updates and restructuring with ModelDrafter integration.

Sequence Diagram(s)

sequenceDiagram
    participant Python as Python Layer
    participant Sampler as TorchSampler
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Buf as GPU Buffers

    Python->>Sampler: update() with requests
    Sampler->>Sampler: _batch_verify_dynamic_tree(requests, tokens)
    Sampler->>DTree: verify_dynamic_tree_greedy(draft_tokens, logits, tree_buffers)
    DTree->>DTree: compute target predictions from logits
    DTree->>Kernel: invoke verify_dynamic_tree_greedy_op()
    Kernel->>Buf: read: candidates, retrieve_index, retrieve_next_sibling, targetPredict
    Kernel->>Buf: greedy tree traversal per batch
    Kernel->>Buf: write: predicts, acceptIndex, acceptTokenNum
    Buf-->>Kernel: results
    Kernel-->>DTree: returns VerifyTreeResults
    DTree-->>Sampler: per-slot (num_accepted_tokens, accept_index)
    Sampler->>Sampler: _process_draft_tokens_dynamic_tree() per request
    Sampler-->>Python: updated tokens and finish reasons
sequenceDiagram
    participant Worker as Eagle3DynamicTreeWorker
    participant DraftModel as Draft Model
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Cache as KV Cache

    Worker->>Worker: _forward_draft_loop(initial context)
    Worker->>DraftModel: forward(input_ids, position_ids)
    DraftModel-->>Worker: logits
    Worker->>Worker: sample_dynamic(logits, topk)
    Worker->>Worker: dt_update_draft_tokens_and_scores()
    Worker->>DTree: build_dynamic_tree(parent_list, topk_indices, tree_buffers)
    DTree->>Kernel: invoke build_dynamic_tree_op()
    Kernel->>Kernel: construct left-child/right-sibling tree
    Kernel->>Kernel: compute per-node attention masks (treeMask)
    Kernel->>Kernel: compute absolute positions
    Kernel-->>DTree: DynamicTreeBuffers (tree_mask, positions, retrieve_index, etc.)
    DTree-->>Worker: tree structure ready
    Worker->>Worker: dt_prepare_tree_mask_and_position_offset()
    Worker->>DraftModel: forward(growing context with tree topology)
    DraftModel->>Cache: update KV cache with tree positions
    DraftModel-->>Worker: logits per tree node
    Worker->>Worker: _sample_and_accept_dynamic_tree(logits)
sequenceDiagram
    participant App as Application
    participant Executor as PyExecutor
    participant ResourceMgr as Eagle3OneModelDynamicTreeResourceManager
    participant Worker as Eagle3OneModelDynamicTreeWorker
    participant Sampler as Eagle3OneModelDynamicTreeSampler

    App->>Executor: initialize with EagleDecodingConfig(use_dynamic_tree=True)
    Executor->>ResourceMgr: create with SpecTreeManager(use_dynamic_tree=True)
    Executor->>Worker: initialize with spec_config
    Executor->>Sampler: initialize with spec_config
    App->>Executor: generate(requests)
    Executor->>Worker: forward(context & draft loop)
    Worker->>Worker: _forward_dynamic_tree_draft_loop()
    Worker-->>Executor: draft_tokens, dynamic_tree_buffers, accepted_draft_indices
    Executor->>Sampler: sample_and_accept_draft_tokens(logits, buffers)
    Sampler->>Sampler: verify with dynamic tree buffers
    Sampler-->>Executor: accepted tokens, accepted indices
    Executor->>ResourceMgr: get_needed_resource_to_completion(request)
    ResourceMgr-->>Executor: resource estimates
    Executor->>Worker: prepare_1st_drafter_inputs(with tree topology targets)
    Executor-->>App: generated tokens

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 38.89%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description check — ✅ Passed: The PR description provides a comprehensive summary of changes, clear objectives, and test coverage details.
  • Title check — ✅ Passed: The title clearly and specifically summarizes the main change: adding EAGLE3 dynamic tree speculative decoding support. It is concise and directly related to the substantial feature additions across kernels, ops, and model layers.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 17

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

3346-3386: ⚠️ Potential issue | 🔴 Critical

Zero-draft dynamic-tree requests currently fall into the static-tree verifier.

_batch_verify_dynamic_tree() explicitly skips requests whose draft length is 0, but the fallback here still calls process_draft_tokens(). With a non-null spec_tree_manager, that dispatches to _process_draft_tokens_tree(), which assumes a populated draft tree. If a dynamic-tree request reaches this branch with no drafts, it will fail instead of just emitting the verified token.

🛠️ One possible guard
-                if req.py_seq_slot in dynamic_tree_results:
+                if req.py_seq_slot in dynamic_tree_results:
                     num_accepted = self._process_draft_tokens_dynamic_tree(
                         req, new_tokens_list, finish_reasons, dynamic_tree_results[req.py_seq_slot]
                     )
-
+                elif spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree:
+                    num_accepted = self._process_draft_tokens_greedy(
+                        req, new_tokens=new_tokens_list, finish_reasons=finish_reasons
+                    )
                 else:
                     num_accepted = self.process_draft_tokens(
                         req,
                         new_tokens_tensor=new_tokens,
                         new_tokens_list=new_tokens_list,
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

435-437: Keep the wrapper imports namespaced here too.

Since this is new dispatch code, please import the drafting_loops module and reference these wrappers from that module rather than importing the classes directly. As per coding guidelines, When importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` around lines 435 -
437, Replace the direct class imports DynamicTreeDraftingLoopWrapper,
LinearDraftingLoopWrapper, StaticTreeDraftingLoopWrapper with a namespaced
module import for drafting_loops and update all references to use
drafting_loops.DynamicTreeDraftingLoopWrapper,
drafting_loops.LinearDraftingLoopWrapper, and
drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used
in the dispatch/registration code inside py_executor_creator.py) so the module
is imported, not individual classes.
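
A minimal sketch of the namespaced form this guideline asks for (the dispatch predicate shown is illustrative; see py_executor_creator.py for the real conditions):

from tensorrt_llm._torch.speculative import drafting_loops


def select_drafting_loop_wrapper(static_tree_drafter: bool,
                                 dynamic_tree_drafter: bool):
    # Reference the wrappers through the module namespace.
    if dynamic_tree_drafter:
        return drafting_loops.DynamicTreeDraftingLoopWrapper
    if static_tree_drafter:
        return drafting_loops.StaticTreeDraftingLoopWrapper
    return drafting_loops.LinearDraftingLoopWrapper
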
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

8-9: Keep the drafting-loop import namespaced.

Please import the drafting_loops module and resolve StaticTreeDraftingLoopWrapper through that namespace here instead of importing the class directly. As per coding guidelines, When importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

Also applies to: 56-56

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py` around
lines 8 - 9, Replace the direct class import with a module-level import for the
drafting_loops module and qualify the class through that namespace: change the
current "from tensorrt_llm._torch.speculative.drafting_loops import
StaticTreeDraftingLoopWrapper" to "import
tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from
tensorrt_llm._torch.speculative import drafting_loops") and update all usages to
drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage
at the other occurrence noted).
tensorrt_llm/_torch/attention_backend/interface.py (1)

371-381: Making trailing parameters keyword-only would improve API safety, but it is not required.

The code is already safe: all call sites either use keyword arguments (model_engine.py:3552, test cases) or correctly pass all 9 positional arguments in order (sparse/dsa.py:541). No stale callers with 7–8 positional arguments (which would silently misbind after inserting is_target_model) exist in the codebase.

If keyword-only enforcement is desired for defensiveness, add * before is_target_model:

def update_spec_dec_param(
        self,
        batch_size,
        is_spec_decoding_enabled,
        is_spec_dec_tree,
        is_spec_dec_dynamic_tree,
        max_draft_len,
        max_total_draft_tokens,
+       *,
        is_target_model: bool = True,
        model_is_wrapped: bool = False,
        spec_tree_manager: Optional['SpecTreeManager'] = None):

This prevents future positional misuse and makes the intent explicit, but the current codebase is already compliant.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/interface.py` around lines 371 - 381,
The function update_spec_dec_param currently accepts many trailing
boolean/optional parameters positionally; make is_target_model,
model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare *
before is_target_model in the signature so callers cannot accidentally bind
those flags positionally—update the signature in the update_spec_dec_param
definition and adjust any internal references accordingly (no other logic
changes).
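
A self-contained toy illustrating the effect of the bare * (the signature mirrors the suggestion above; the body is elided):

def update_spec_dec_param(batch_size, is_spec_decoding_enabled,
                          is_spec_dec_tree, is_spec_dec_dynamic_tree,
                          max_draft_len, max_total_draft_tokens, *,
                          is_target_model=True, model_is_wrapped=False,
                          spec_tree_manager=None):
    ...


update_spec_dec_param(1, True, False, True, 4, 12, is_target_model=False)  # OK
# update_spec_dec_param(1, True, False, True, 4, 12, False)
# -> TypeError: update_spec_dec_param() takes 6 positional arguments but 7 were given
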
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

502-509: Consider clarifying the reshape assumption.

The reshape from 1D [max_num_requests * N] to 2D [max_num_requests, N] assumes the 1D tensor was allocated with exactly max_num_requests * (max_total_draft_tokens + 1) elements. This is correct based on line 1463-1465, but the implicit coupling between allocation and reshape could be fragile.

Consider adding an assertion or comment to make this contract explicit:

📝 Suggestion for defensive check
         # For dynamic tree, reshape 1D position_offsets to 2D for C++ kernel compatibility
         position_offsets_for_cpp = self.spec_decoding_position_offsets
         if (self.spec_decoding_position_offsets is not None
                 and self.spec_decoding_position_offsets.dim() == 1):
             # Reshape 1D [max_num_requests * N] to 2D [max_num_requests, N]
             # C++ kernel requires 2D to extract max_generation_length from sizes()[1]
+            assert self.spec_decoding_position_offsets.numel() % self.max_num_requests == 0, \
+                "1D position_offsets size must be divisible by max_num_requests"
             position_offsets_for_cpp = self.spec_decoding_position_offsets.view(
                 self.max_num_requests, -1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 502 - 509, The
reshape from 1D to 2D (position_offsets_for_cpp based on
self.spec_decoding_position_offsets) assumes the 1D tensor length equals
max_num_requests * N; add a defensive check before the view to assert that
self.spec_decoding_position_offsets.numel() is divisible by
self.max_num_requests and (optionally) equals self.max_num_requests *
(self.max_total_draft_tokens + 1) (or raise a clear error mentioning
spec_decoding_position_offsets and max_num_requests) so the implicit
allocation/reshape contract in trtllm.py is explicit and fails fast when
violated.
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

2626-2668: Materialize accept_index before the per-request Python loop.

Line 2649's accept_index[j].item() performs a device read for every accepted token in a hot request loop. Convert the accepted indices once, then iterate over a Python list here, matching the existing new_tokens.tolist() pattern. Based on learnings: In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/sampler.py` around lines 2626 - 2668, The loop
in _process_draft_tokens_dynamic_tree repeatedly calls accept_index[j].item(),
causing device reads per-iteration; materialize accept_index to a Python list
once before the request loop (e.g. accept_indices = accept_index.tolist() or
accept_index.cpu().tolist() and cast elements to int), then iterate over
accept_indices for add_token and finish_if_reason calls, and compute
request.py_num_accepted_draft_tokens_indices from that list by subtracting 1 for
positions after the root; keep using the same symbols (accept_index ->
accept_indices, _process_draft_tokens_dynamic_tree, add_token, finish_if_reason,
request.py_num_accepted_draft_tokens_indices) so you only replace tensor
indexing with list indexing.
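
A minimal, self-contained sketch of that pattern: one host transfer before the loop, plain-int indexing inside it (the real loop also calls add_token() and finish_if_reason(), elided here):

import torch

accept_index = torch.tensor([0, 2, 3])  # lives on the GPU in the real sampler

accept_indices = accept_index.tolist()  # single sync, outside the hot loop
for j, idx in enumerate(accept_indices):
    print(f"step {j}: accepted tree node {idx}")  # idx is a plain int

# Positions after the root map to draft-token indices by subtracting 1.
py_num_accepted_draft_tokens_indices = [i - 1 for i in accept_indices[1:]]
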
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 112-182: The kernel reuses preallocated topology buffers but
doesn’t clear stale data; before building the tree (inside the tid==0 branch of
the dynamic tree builder) explicitly reinitialize retrieveNextToken and
retrieveNextSibling entries for this batch (bid) to -1 for all draftTokenNum
slots, and clear all words of treeMask for this batch (not just word 0); also
ensure positions/retrieveIndex for all slots are set to sane defaults if needed.
Locate the tid==0 block that sets positions[bid * draftTokenNum] and the loop
that writes retrieveIndex/retrieveNextToken/retrieveNextSibling and add the
resets there (and mirror the same full-reset logic in the other build region
referenced around lines 245-317).
- Around line 191-214: The ancestor-walk loop can run past bounds when a parent
lookup misses: after the for-loop that searches selectedIndex for tokenIdx
(using curPosition and draftTokenNum) add a guard to detect "not found"
(curPosition == draftTokenNum) and break the while loop to avoid reading/writing
past selectedIndex/treeMask; apply the same defensive check to the equivalent
ancestor-walk logic around the other block referenced (uses the same symbols:
treeMask, tokenTreeIdx, curPosition, selectedIndex, draftTokenNum, parentList,
parentTbIdx, bid, topK) so both paths stop if the parent was not resampled.

In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 17: The header currently uses `#pragma once` but must follow the repo guard
convention; replace the pragma with a preprocessor include guard named
TRTLLM_DYNAMICTREEKERNELS_H (matching the filename dynamicTreeKernels.h in ALL
CAPS) by adding `#ifndef TRTLLM_DYNAMICTREEKERNELS_H` / `#define
TRTLLM_DYNAMICTREEKERNELS_H` at the top and a matching `#endif` at the bottom,
ensuring no directory names or trailing underscores are used and keeping the
rest of the file (dynamicTreeKernels.h) unchanged.

In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 34-67: build_dynamic_tree_op (and its sibling
verify_dynamic_tree_greedy_op) currently access raw data pointers and call
at::cuda::getCurrentCUDAStream() without validating devices, dtypes, shapes or
the treeMaskMode enum; add TORCH_CHECKs to ensure all input/output tensors
(parentList, selectedIndex, treeMask, positions, retrieveIndex,
retrieveNextToken, retrieveNextSibling, verifiedSeqLen) are CUDA tensors
(is_cuda()), are on the same device (device.index() equality), and have the
expected scalar types (parentList/selectedIndex int64,
positions/retrieveIndex/retrieveNextToken/retrieveNextSibling/verifiedSeqLen
int32 as used by data_ptr<int32_t/int64_t>()), verify output shapes (batchSize,
numDraftTokens-1, etc.) before zero_/fill_, and check treeMaskMode is within the
valid tk::TreeMaskMode range before static_cast; perform these checks at the
start of build_dynamic_tree_op and verify_dynamic_tree_greedy_op so that
tk::invokeBuildDynamicTree and related kernel calls only receive validated
tensors and enum values.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 540: The forward reference 'SpecTreeManager' used in the signature
(spec_tree_manager: Optional['SpecTreeManager']) is not imported under
TYPE_CHECKING; add "from tensorrt_llm._torch.speculative.spec_tree_manager
import SpecTreeManager" to the existing TYPE_CHECKING import block (alongside
any existing imports such as DecodingBaseConfig) so the name is resolved for
type checking and Ruff F821 is fixed.
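
A sketch of that TYPE_CHECKING fix (the simplified signature is illustrative):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Resolves the quoted annotation for type checkers (and Ruff F821)
    # without importing the module at runtime.
    from tensorrt_llm._torch.speculative.spec_tree_manager import \
        SpecTreeManager


def update_spec_dec_param(
        spec_tree_manager: Optional["SpecTreeManager"] = None) -> None:
    ...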

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 444-461: The leftover variable use_tree_drafter is assigned
earlier but never used after splitting into static_tree_drafter and
dynamic_tree_drafter, causing a linter F841; remove the unused use_tree_drafter
assignment (or fold its logic into the existing predicates) so only
static_tree_drafter and dynamic_tree_drafter derived from draft_spec_config
(EagleDecodingConfig) remain, leaving the branching that returns
StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper unchanged
(references: use_tree_drafter, static_tree_drafter, dynamic_tree_drafter,
draft_spec_config, spec_config, StaticTreeDraftingLoopWrapper,
DynamicTreeDraftingLoopWrapper).

In `@tensorrt_llm/_torch/speculative/drafting_loops.py`:
- Around line 709-718: The code unconditionally overwrites return_draft_logits
with zeros losing real collected logits; change the logic in drafting_loops.py
so that you only allocate the zero tensor as a fallback when return_draft_logits
is missing or has an incompatible shape (e.g., check if return_draft_logits is
None or return_draft_logits.shape != (self.max_total_draft_tokens, batch_size,
vocab_size)); otherwise preserve the existing return_draft_logits from the last
draft layer; when allocating the fallback ensure dtype/device match
(torch.float32 and 'cuda') and add a brief comment referencing
tokens_accumulated to indicate this is a temporary fallback until per-layer
gathering is implemented.
- Around line 1164-1180: The attn_metadata.use_spec_decoding flag is left False
so subsequent drafter forwards ignore the dynamic-tree metadata; set
attn_metadata.use_spec_decoding = True at the end of this preparation block
(after updating kv_lens_cuda, _seq_lens, host_request_types and before leaving
the dynamic-tree growth steps) so the next draft pass uses speculative decoding;
locate the block updating attn_metadata.kv_lens_cuda, attn_metadata._seq_lens,
attn_metadata.host_request_types and set attn_metadata.use_spec_decoding = True
there (ensure this happens before
spec_metadata.eagle3_resource_manager.is_first_draft is toggled).
- Around line 961-975: spec_decoding_position_offsets is being treated as a flat
vector but it’s stored as a 2-D buffer ([max_num_requests,
max_total_draft_tokens+1]); update the code to slice and assign it as 2-D so
rows correspond to requests: read previous_position_offsets =
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_previous_layer], build new_position_offsets by concatenating along
dim=1 (using previous_position_offsets and previous_position_offsets[:,
-self.dynamic_tree_max_topK:]+1), then write it back to
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_current_layer] (no flattening/view needed) so the correct request
rows are updated.
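
A self-contained sketch of the request-major update described above, assuming spec_decoding_position_offsets keeps its [max_num_requests, max_total_draft_tokens + 1] layout (the function wrapper and names are illustrative):

import torch


def grow_position_offsets(position_offsets: torch.Tensor, batch_size: int,
                          num_tokens_previous_layer: int,
                          dynamic_tree_max_topK: int) -> None:
    # Slice rows (requests) and columns (tree tokens) of the 2-D buffer.
    prev = position_offsets[:batch_size, :num_tokens_previous_layer]
    new_offsets = torch.cat([prev, prev[:, -dynamic_tree_max_topK:] + 1],
                            dim=1)
    num_tokens_current_layer = (num_tokens_previous_layer +
                                dynamic_tree_max_topK)
    position_offsets[:batch_size, :num_tokens_current_layer] = new_offsets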

In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 1-12: Add the standard NVIDIA Apache-2.0 license header (with the
latest modification year) at the top of the file before the existing module
docstring in dynamic_tree_ops.py; replace the current file-starting
docstring-only content by prepending the required NVIDIA copyright/license block
so the file begins with the Apache 2.0 header followed by the existing "Dynamic
Tree Operations for EAGLE3 Speculative Decoding" docstring.

In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 98-143: The buffers currently fix max_batch_size = 256 which can
overflow for larger deployments; replace the hard-coded max_batch_size with a
runtime value derived from spec_config or the worker's actual max concurrent
sequences (e.g., spec_config.max_batch_size or a passed-in parameter) and use
that variable when allocating dt_draft_tokens_buffer, dt_position_ids_buffer,
history_draft_tokens_buffer, history_score_buffer,
history_draft_tokens_parent_buffer, tree_mask_buffer, tree_mask_init_buffer, and
tree_mask_padding_zeros, and when calling create_dynamic_tree_ops_converter
(preserve device and dtypes). Ensure the new max_batch_size value is validated
(positive int) and that any flattened-size calculations (like tree_mask_buffer
shape) are updated to use the dynamic max_batch_size to avoid out-of-bounds
writes.
- Around line 475-487: The code incorrectly reshapes
spec_decoding_position_offsets into a flat (max_reqs, tokens_per_req) layout
which conflicts with the request-major layout used elsewhere; replace the manual
flattening with a request-major view and index into the existing first
dimension. Concretely, stop computing max_reqs = total_po_size // tokens_per_req
and using pos_2d = attn_metadata.spec_decoding_position_offsets.view(max_reqs,
tokens_per_req); instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.
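
A hedged sketch of sizing those buffers from a runtime value instead of a hard-coded 256 (buffer names follow the review text; the shapes and how max_batch_size reaches the worker are assumptions):

import torch


def allocate_dynamic_tree_buffers(max_batch_size: int,
                                  max_total_draft_tokens: int) -> dict:
    if not isinstance(max_batch_size, int) or max_batch_size <= 0:
        raise ValueError(
            f"max_batch_size must be a positive int, got {max_batch_size}")
    return {
        "dt_draft_tokens_buffer":
            torch.zeros(max_batch_size, max_total_draft_tokens,
                        dtype=torch.int32, device="cuda"),
        "dt_position_ids_buffer":
            torch.zeros(max_batch_size, max_total_draft_tokens + 1,
                        dtype=torch.int32, device="cuda"),
    }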

In `@tensorrt_llm/_torch/speculative/model_drafter.py`:
- Around line 686-690: Remove the two unused CPU buffer assignments to avoid
dead code: delete the assignments to topk_score_indices and
history_draft_tokens_parent_buffer that read from
dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.
- Around line 1004-1005: The prepare_draft_tokens method in ModelDrafter
currently requires resource_manager but the base class Drafter defines it as
optional; change the signature of ModelDrafter.prepare_draft_tokens to accept
resource_manager: Optional[ResourceManager] = None so it matches the base
contract, add or ensure Optional is imported from typing if missing, and mirror
the pattern used in ngram.py; update any internal usage of resource_manager to
handle None safely.

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 1865-1876: The file contains a duplicate TypeAlias named
SpeculativeConfig that shadows the earlier discriminated union (the one defined
with Field(discriminator="decoding_type")), which removes SADecodingConfig and
PARDDecodingConfig; remove the second SpeculativeConfig definition (or rename it
if you truly need a separate non-discriminated alias) and keep the original
annotated union (including SADecodingConfig and PARDDecodingConfig alongside
DraftTargetDecodingConfig, Eagle3DecodingConfig, EagleDecodingConfig,
LookaheadDecodingConfig, MedusaDecodingConfig, MTPDecodingConfig,
NGramDecodingConfig, UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig,
AutoDecodingConfig) so the discriminator-based Pydantic union remains intact.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 388f70bd-848c-4020-99de-5838cd97e5b3

📥 Commits

Reviewing files that changed from the base of the PR and between 3139ffa and b15f7d6.

📒 Files selected for processing (23)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py

Comment on lines +362 to +377
for i, prompt in enumerate(prompts):
    num_tokens = 0
    num_iterations = 0
    for output in llm.generate_async(prompt,
                                     sampling_params,
                                     streaming=True):
        new_tokens = output.outputs[0].token_ids
        num_tokens = len(new_tokens)
        num_iterations += 1
    if num_iterations > 0:
        accept_rate = num_tokens / num_iterations
        print(f"[{i}] Accept rate: {accept_rate:.2f} "
              f"(tokens={num_tokens}, iterations={num_iterations})")
    generated_text = output.outputs[0].text
    print(
        f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
Contributor


⚠️ Potential issue | 🟡 Minor

Make the streaming branch handle all sequences and empty streams.

This path assumes output.outputs[0] is the only sequence and that the iterator always yields at least once. That means --n > 1 / beam-search results are silently dropped here, and an empty stream would leave output unassigned when the final print runs. Either iterate over every output.outputs entry like the non-streaming path does, or explicitly reject those combinations in streaming mode.

Comment on lines +475 to +487
tokens_per_req = spec_metadata.max_total_draft_tokens + 1
total_po_size = attn_metadata.spec_decoding_position_offsets.shape[0]
max_reqs = total_po_size // tokens_per_req
pos_2d = attn_metadata.spec_decoding_position_offsets.view(
    max_reqs, tokens_per_req
)
max_gl = int(gen_sl.max().item())
causal_offs = torch.arange(max_gl, device="cuda", dtype=torch.int32)
for g_idx in range(num_gens):
    req_idx = num_contexts + g_idx
    n = int(gen_sl[g_idx].item())
    pos_2d[req_idx, :n] = causal_offs[:n]

Contributor


⚠️ Potential issue | 🔴 Critical

Keep spec_decoding_position_offsets in request-major layout.

This file mixes two incompatible views of spec_decoding_position_offsets: step 0 reshapes it as if it were flat, and the later update path slices it as [: batch_size * num_tokens_previous_layer]. If the metadata keeps the same [max_num_requests, max_total_draft_tokens + 1] layout used elsewhere in the speculative stack, both blocks will update the wrong memory once multiple requests are active.

Also applies to: 1090-1101

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py` around lines 475 -
487, The code incorrectly reshapes spec_decoding_position_offsets into a flat
(max_reqs, tokens_per_req) layout which conflicts with the request-major layout
used elsewhere; replace the manual flattening with a request-major view and
index into the existing first dimension. Concretely, stop computing max_reqs =
total_po_size // tokens_per_req and using pos_2d =
attn_metadata.spec_decoding_position_offsets.view(max_reqs, tokens_per_req);
instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.
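
A self-contained sketch of the request-major fix: index rows of the existing [max_num_requests, max_total_draft_tokens + 1] buffer directly instead of re-deriving the row count from numel (the function wrapper is illustrative):

import torch


def fill_causal_offsets(position_offsets: torch.Tensor, num_contexts: int,
                        gen_seq_lens: torch.Tensor) -> None:
    seq_lens = gen_seq_lens.tolist()  # avoid per-iteration .item() reads
    if not seq_lens:
        return
    causal_offs = torch.arange(max(seq_lens), device=position_offsets.device,
                               dtype=position_offsets.dtype)
    for g_idx, n in enumerate(seq_lens):
        req_idx = num_contexts + g_idx
        assert req_idx < position_offsets.shape[0], "req_idx out of bounds"
        position_offsets[req_idx, :n] = causal_offs[:n]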

Comment on lines +686 to +690
topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu(
)  # [batch_size, self.max_total_draft_tokens]
history_draft_tokens_parent_buffer = dynamic_tree_buffers[
    "history_draft_tokens_parent_buffer"].cpu(
    )  # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)]
Contributor


⚠️ Potential issue | 🟡 Minor

Remove unused variables flagged by static analysis.

The variables topk_score_indices and history_draft_tokens_parent_buffer are assigned but never used in this method. If they are intended for future use, consider adding a TODO comment; otherwise, remove them to avoid confusion.

🧹 Proposed fix
         if isinstance(
                 self.spec_config,
                 EagleDecodingConfig) and self.spec_config.use_dynamic_tree:
             dynamic_tree_buffers = outputs["dynamic_tree_buffers"]
-            topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu(
-            )  # [batch_size, self.max_total_draft_tokens]
-            history_draft_tokens_parent_buffer = dynamic_tree_buffers[
-                "history_draft_tokens_parent_buffer"].cpu(
-                )  # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)]
+            # Note: topk_score_indices and history_draft_tokens_parent_buffer are available
+            # in dynamic_tree_buffers if needed for debugging or future use
             tree_structure = dynamic_tree_buffers.get("tree_structure")
🧰 Tools
🪛 Ruff (0.15.5)

[error] 686-686: Local variable topk_score_indices is assigned to but never used

Remove assignment to unused variable topk_score_indices

(F841)


[error] 688-688: Local variable history_draft_tokens_parent_buffer is assigned to but never used

Remove assignment to unused variable history_draft_tokens_parent_buffer

(F841)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/model_drafter.py` around lines 686 - 690,
Remove the two unused CPU buffer assignments to avoid dead code: delete the
assignments to topk_score_indices and history_draft_tokens_parent_buffer that
read from dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.

Comment on lines +1004 to +1005
    def prepare_draft_tokens(self, scheduled_requests: ScheduledRequests,
                             resource_manager: ResourceManager) -> None:
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check if other implementations of prepare_draft_tokens have compatible signatures

# Find all prepare_draft_tokens method definitions
rg -n "def prepare_draft_tokens" --type=py -A 3

# Check the base class signature
cat tensorrt_llm/_torch/speculative/drafter.py | head -50



Fix LSP violation: make resource_manager parameter optional in ModelDrafter.prepare_draft_tokens().

The method signature at lines 1004-1005 requires resource_manager: ResourceManager, but the abstract base class at drafter.py:27-30 defines it as Optional[ResourceManager] = None. This violates Liskov Substitution Principle—callers cannot treat ModelDrafter instances polymorphically as Drafter without potentially failing if resource_manager is omitted. The ngram.py implementation maintains the correct signature. Update the method signature to match the base class contract: resource_manager: Optional[ResourceManager] = None.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/model_drafter.py` around lines 1004 - 1005,
The prepare_draft_tokens method in ModelDrafter currently requires
resource_manager but the base class Drafter defines it as optional; change the
signature of ModelDrafter.prepare_draft_tokens to accept resource_manager:
Optional[ResourceManager] = None so it matches the base contract, add or ensure
Optional is imported from typing if missing, and mirror the pattern used in
ngram.py; update any internal usage of resource_manager to handle None safely.

Comment on lines +971 to +990
        if self.use_dynamic_tree or self.dynamic_tree_max_topK is not None:
            self.use_dynamic_tree = True
            assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK is required for dynamic tree"
            assert self.eagle_choices is None, "If use_dynamic_tree is True, eagle_choices should be None"
            total_history_draft_tokens = self.dynamic_tree_max_topK + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK * (
                self.max_draft_len - 1)
            default_max_total_draft_tokens = self.dynamic_tree_max_topK * self.max_draft_len

            if self.max_total_draft_tokens is None:
                self.max_total_draft_tokens = default_max_total_draft_tokens
                logger.warning(
                    f"max_total_draft_tokens is not provided, use the default value {default_max_total_draft_tokens} (default_max_total_draft_tokens = dynamic_tree_max_topK * max_draft_len)"
                )
            else:
                assert self.max_total_draft_tokens <= total_history_draft_tokens and self.max_total_draft_tokens >= default_max_total_draft_tokens, f"max_total_draft_tokens should be between {default_max_total_draft_tokens} and {total_history_draft_tokens}"

        # Linear tree
        if self.max_total_draft_tokens is None:
            self.max_total_draft_tokens = self.max_draft_len

Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Get line count first to understand file size
wc -l tensorrt_llm/llmapi/llm_args.py



🏁 Script executed:

# Read the lines around 953-990 to see the full context of the dynamic tree validation
sed -n '940,1000p' tensorrt_llm/llmapi/llm_args.py



🏁 Script executed:

# Check the SpeculativeConfig definitions around lines 1847-1876
sed -n '1840,1890p' tensorrt_llm/llmapi/llm_args.py



🏁 Script executed:

# Check the quickstart_advanced.py file mentioned in the scratchpad
sed -n '240,260p' examples/llm-api/quickstart_advanced.py



Replace assert statements with ValueError and consolidate dynamic-tree validation logic.

The unreachable defaulting at line 981 confirms the existing validation at line 962 will reject omitted max_total_draft_tokens before the fallback runs. Additionally, lines 986-987 and 989 use assert statements that violate the coding guideline requiring ValueError to be raised in Pydantic validators. Consolidate the two dynamic-tree blocks into a single validator, remove the assert statements, and handle both the validation and the new defaulting logic within one flow path.
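
A hedged sketch of the consolidated validator raising ValueError, keeping the bounds the current code enforces (one possible structure, not the PR's final code):

def _validate_dynamic_tree(self) -> None:
    if not self.use_dynamic_tree and self.dynamic_tree_max_topK is None:
        if self.max_total_draft_tokens is None:  # linear tree
            self.max_total_draft_tokens = self.max_draft_len
        return
    self.use_dynamic_tree = True
    if self.dynamic_tree_max_topK is None or self.dynamic_tree_max_topK <= 0:
        raise ValueError("dynamic_tree_max_topK is required for dynamic tree")
    if self.eagle_choices is not None:
        raise ValueError(
            "If use_dynamic_tree is True, eagle_choices should be None")
    top_k = self.dynamic_tree_max_topK
    default_total = top_k * self.max_draft_len
    history_total = top_k + top_k * top_k * (self.max_draft_len - 1)
    if self.max_total_draft_tokens is None:
        self.max_total_draft_tokens = default_total
    elif not default_total <= self.max_total_draft_tokens <= history_total:
        raise ValueError(
            f"max_total_draft_tokens should be between {default_total} "
            f"and {history_total}")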

Comment on lines +1865 to +1876
SpeculativeConfig: TypeAlias = Optional[Union[
    DraftTargetDecodingConfig,
    Eagle3DecodingConfig,  # Must be before EagleDecodingConfig since it's a subclass
    EagleDecodingConfig,
    LookaheadDecodingConfig,
    MedusaDecodingConfig,
    MTPDecodingConfig,
    NGramDecodingConfig,
    UserProvidedDecodingConfig,
    SaveHiddenStatesDecodingConfig,
    AutoDecodingConfig,
]]
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l tensorrt_llm/llmapi/llm_args.py



🏁 Script executed:

sed -n '1850,1900p' tensorrt_llm/llmapi/llm_args.py | cat -n



🏁 Script executed:

rg "SpeculativeConfig\s*:" tensorrt_llm/llmapi/llm_args.py -n



🏁 Script executed:

rg "speculative_config" tensorrt_llm/llmapi/llm_args.py -n -A 3



🏁 Script executed:

rg "class.*DecodingConfig" tensorrt_llm/llmapi/llm_args.py -n | head -20



🏁 Script executed:

sed -n '1840,1880p' tensorrt_llm/llmapi/llm_args.py | cat -n



🏁 Script executed:

sed -n '1847,1863p' tensorrt_llm/llmapi/llm_args.py | cat -n



Remove the second SpeculativeConfig definition or rename it.

The second SpeculativeConfig definition (lines 1865-1876) shadows the first (lines 1847-1863), removing the discriminator field and excluding SADecodingConfig and PARDDecodingConfig. This violates the guideline to use discriminated unions in Pydantic for fields accepting multiple config classes. Keep the first annotated definition with Field(discriminator="decoding_type") and either remove the second definition or give it a different name if an optional non-discriminated alias is needed elsewhere.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 1865 - 1876, The file contains
a duplicate TypeAlias named SpeculativeConfig that shadows the earlier
discriminated union (the one defined with Field(discriminator="decoding_type")),
which removes SADecodingConfig and PARDDecodingConfig; remove the second
SpeculativeConfig definition (or rename it if you truly need a separate
non-discriminated alias) and keep the original annotated union (including
SADecodingConfig and PARDDecodingConfig alongside DraftTargetDecodingConfig,
Eagle3DecodingConfig, EagleDecodingConfig, LookaheadDecodingConfig,
MedusaDecodingConfig, MTPDecodingConfig, NGramDecodingConfig,
UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig, AutoDecodingConfig)
so the discriminator-based Pydantic union remains intact.
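
A minimal sketch of the discriminated-union shape the review asks to keep (only two member configs shown, with illustrative decoding_type literals; the real alias lists every decoding config, including SADecodingConfig and PARDDecodingConfig):

from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field


class EagleDecodingConfig(BaseModel):
    decoding_type: Literal["Eagle"] = "Eagle"


class MTPDecodingConfig(BaseModel):
    decoding_type: Literal["MTP"] = "MTP"


SpeculativeConfig = Optional[Annotated[
    Union[EagleDecodingConfig, MTPDecodingConfig],
    Field(discriminator="decoding_type"),
]]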

@tensorrt-cicd
Collaborator

PR_Github #38374 [ run ] completed with state FAILURE. Commit: b15f7d6
/LLM/main/L0_MergeRequest_PR pipeline #29741 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg sunnyqgg force-pushed the add_dyanmic_tree_support_one_model branch from 99b25a9 to c2c8ef6 on March 12, 2026 01:52
@sunnyqgg
Collaborator Author

/bot run

@sunnyqgg sunnyqgg force-pushed the add_dyanmic_tree_support_one_model branch from c2c8ef6 to abd7543 on March 15, 2026 07:36
@sunnyqgg
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #38968 [ run ] triggered by Bot. Commit: abd7543 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #38968 [ run ] completed with state FAILURE. Commit: abd7543
/LLM/main/L0_MergeRequest_PR pipeline #30250 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg
Collaborator Author

/bot run

1 similar comment
@sunnyqgg
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39082 [ run ] triggered by Bot. Commit: bb00556 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39082 [ run ] completed with state FAILURE. Commit: bb00556
/LLM/main/L0_MergeRequest_PR pipeline #30345 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39197 [ run ] triggered by Bot. Commit: a79c91d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40938 [ run ] completed with state SUCCESS. Commit: d4a9537
/LLM/main/L0_MergeRequest_PR pipeline #31930 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Collaborator

@laikhtewari laikhtewari left a comment


There should be documentation on how to use this in the speculative decoding feature page

The feature combination matrix should also be updated

@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

/bot run

@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

There should be documentation on how to use this in the speculative decoding feature page

The feature combination matrix should also be updated

Hi @laikhtewari, I updated the doc, and the feature combination matrix is already covered.

@tensorrt-cicd
Collaborator

PR_Github #41072 [ run ] triggered by Bot. Commit: c6a30b9 Link to invocation

@sunnyqgg sunnyqgg changed the title from [None][feat] Add EAGLE3 dynamic tree speculative decoding support to [TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support on Apr 1, 2026
@tensorrt-cicd
Collaborator

PR_Github #41072 [ run ] completed with state FAILURE. Commit: c6a30b9
/LLM/main/L0_MergeRequest_PR pipeline #32048 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…e override decorators

Signed-off-by: qgai <qgai@nvidia.com>
…ch_size on Eagle3DecodingConfig

Signed-off-by: qgai <qgai@nvidia.com>
…for EAGLE3 dynamic tree

Signed-off-by: qgai <qgai@nvidia.com>
…ailures

Provide valid eagle_choices for static-tree SpecTreeManager on H100 (sm<100)
to avoid TypeError when iterating None. Relax logits tolerance from 0.4 to
1.0 on B200 since greedy argmax match is the real correctness gate.

Signed-off-by: qgai <qgai@nvidia.com>
…est + add docs

Switch SpecTreeManager in test_llama_verification_with_kv_cache_relocation
from static tree (use_dynamic_tree=False) to dynamic tree mode, removing
the eagle_choices parameter. Fixes RuntimeError on H100 (sm<100) where
flat single-level eagle_choices produced empty top_k_list tensors.

Also add EAGLE3 dynamic tree mode documentation to speculative-decoding.md
per reviewer request.

Signed-off-by: qgai <qgai@nvidia.com>
@sunnyqgg sunnyqgg force-pushed the add_dyanmic_tree_support_one_model branch from c6a30b9 to 34a3f35 Compare April 1, 2026 04:40
@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41105 [ run ] triggered by Bot. Commit: 34a3f35 Link to invocation

_update_kv_cache_draft_token_location(self, scheduled_batch,
                                      attn_metadata,
                                      kv_cache_dtype_byte_size)


@venkywonka venkywonka Apr 1, 2026


I don't know anything about dynamic tree specdec, but I'm flagging this potential bug raised by Devin: https://app.devin.ai/review/NVIDIA/TensorRT-LLM/pull/12062#bug-BUG_pr-review-job-d1faac1ad51d4ae4b06c1576a9c56332_0001

tensorrt_llm/_torch/pyexecutor/resource_manager.py:R807-812

Removal of _update_kv_cache_draft_token_location from update_resources breaks two-model spec-dec flows

The call to _update_kv_cache_draft_token_location was removed from KVCacheManager.update_resources() at tensorrt_llm/_torch/pyexecutor/resource_manager.py:801-805. This function compacted accepted draft tokens' KV cache entries into contiguous positions after verification. While the dynamic-tree one-model flow replaces this with _relocate_kv_eagerly inside the worker, all two-model speculative decoding flows (EAGLE3 two-model, MTP two-model, DraftTarget) relied on this call path and now silently skip KV cache relocation. This causes stale/incorrect KV cache entries to remain in subsequent decoding iterations, leading to wrong outputs. The function _update_kv_cache_draft_token_location is now effectively dead code since it has no remaining callers.
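
For context, a minimal pure-PyTorch sketch of what the relocation in question does conceptually; the function name, shapes, and the linear (non-paged) layout are simplifying assumptions for illustration, not the PR's actual kernel:

```python
import torch

# Hypothetical illustration (not the PR's kernel): after verification, the KV
# entries written for the accepted draft tokens may sit at scattered slots in
# the sequence dimension; compacting them lets the next iteration see a
# contiguous KV history.
def relocate_accepted_draft_kv(kv: torch.Tensor, seq_len: int,
                               accepted_positions: list[int]) -> int:
    """kv: [num_heads, max_len, head_dim]. Draft tokens occupy slots from
    seq_len upward; copy the accepted ones down so they directly follow the
    existing history, and return the new sequence length."""
    dst = seq_len
    for src in sorted(accepted_positions):  # src >= dst always, no aliasing
        if src != dst:
            kv[:, dst, :] = kv[:, src, :]
        dst += 1
    return dst

# Usage: 4 tokens of history, drafts in slots 4..9, and verification
# accepted the drafts sitting at slots 5 and 8.
kv = torch.randn(2, 16, 8)
new_len = relocate_accepted_draft_kv(kv, seq_len=4, accepted_positions=[5, 8])
assert new_len == 6  # history (4) + accepted drafts (2), now contiguous
```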

Collaborator Author


Hi,
Regarding "all two-model speculative decoding flows (EAGLE3 two-model, MTP two-model, DraftTarget) relied on this call path and now silently skip KV cache relocation": none of these flows needs to run _update_kv_cache_draft_token_location, so I deleted it directly.

@tensorrt-cicd
Collaborator

PR_Github #41105 [ run ] completed with state FAILURE. Commit: 34a3f35
/LLM/main/L0_MergeRequest_PR pipeline #32079 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41145 [ run ] triggered by Bot. Commit: 34a3f35 Link to invocation


@eopXD eopXD left a comment


  1. Regarding _relocate_kv_eagerly, I think we should still keep the KV cache management workflow concentrated under resource_manager.py. This serves the same goal as the _update_kv_cache_draft_token_location you removed, but for the 2D KV layout you have implemented. Is it possible for you to add new methods under resource_manager.py and update how we deal with draft token locations under update_resources? We also need to make sure that all existing cases are still covered after the removal of _update_kv_cache_draft_token_location.

  2. Under _update_kv_cache_draft_token_location, thank you for the fix specifying the correct number of kv_heads for TP. The fix has an implicit assumption that TP shards are uniform. We can resolve this by adding a function that returns the correct TP number (see the sketch after this list). This gives us flexibility and avoids future bugs.

  3. We should add a unit test for the 2D KV relocation kernel.
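
A sketch of the kind of helper point 2 suggests; the name and exact placement are hypothetical, not from the PR:

```python
from collections.abc import Sequence

# Hypothetical helper (not in the PR): validate that the per-rank KV head
# count is uniform across layers before handing the kernel a single scalar,
# so an irregular TP sharding fails loudly instead of computing wrong strides.
def uniform_num_kv_heads(num_kv_heads_per_layer: Sequence[int]) -> int:
    heads = set(num_kv_heads_per_layer)
    assert len(heads) == 1, (
        f"KV relocation assumes layer-uniform KV heads, got {sorted(heads)}")
    return num_kv_heads_per_layer[0]
```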

cache_mgr.num_kv_heads_per_layer[0],
self._kv_head_dim_bytes,
cache_mgr.max_total_draft_tokens,
cache_mgr.max_attention_window_vec[0],
Collaborator


Does this mean we don't support VSWA now?

Collaborator Author


yes

cache_mgr.num_layers,
# Use TP-sharded num_kv_heads (per-rank) instead of the unsharded
# total so the C++ kernel computes correct strides and grid dims.
cache_mgr.num_kv_heads_per_layer[0],
Collaborator


Ditto, wrap a function around this to free ourselves from the implicit assumption.

Collaborator Author


I added an assert for it.

@tensorrt-cicd
Collaborator

PR_Github #41145 [ run ] completed with state SUCCESS. Commit: 34a3f35
/LLM/main/L0_MergeRequest_PR pipeline #32114 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

  1. Regarding keeping the KV cache management workflow under resource_manager.py: thanks, but the timing of _relocate_kv_eagerly is different from _update_kv_cache_draft_token_location. Due to the overlap scheduler it has to be done here, and it runs inside the CUDA graph. _update_kv_cache_draft_token_location will be deleted soon anyway, as it is only used by the static tree, which will be deprecated.

  2. Thanks, but I didn't fully get this question. I think the computation of cache_mgr.num_kv_heads_per_layer[0] should stay in KVCacheManager; update_kv_cache_draft_token_location_2d only supports one case, so I added an assert for it.

  3. Thanks, done.

…2D paged KV cache test

- Add assertion before update_kv_cache_draft_token_location and
  update_kv_cache_draft_token_location_2d calls to validate that
  num_kv_heads_per_layer is uniform across all layers, since the
  underlying C++ kernel takes a single scalar numKVHeads.
- Add unit test for the 2D paged KV cache draft token relocation kernel.
- Fix xqaDispatcher multiCtasKvMode and decoder info condition.

Signed-off-by: qgai <qgai@nvidia.com>
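
A reference-level sketch of the kind of unit test this commit describes, using a pure-PyTorch stand-in for the 2D paged relocation; a real test would compare the CUDA op's output (its binding name is not shown in this thread) against such a reference. All names and shapes here are hypothetical.

```python
import torch

# Hypothetical stand-in for the 2D paged relocation: compact accepted draft
# slots so they directly follow the existing history, resolving each logical
# position through a page table.
def relocate_2d_reference(pages: torch.Tensor, page_table: list[int],
                          page_size: int, seq_len: int,
                          accepted: list[int]) -> None:
    # pages: [num_pages, page_size, num_heads, head_dim]
    def slot(pos: int):
        return page_table[pos // page_size], pos % page_size
    dst = seq_len
    for src in sorted(accepted):  # src >= dst, so writes never clobber reads
        sp, so = slot(src)
        dp, do = slot(dst)
        pages[dp, do] = pages[sp, so]
        dst += 1

def test_relocate_2d_compacts_accepted_tokens():
    torch.manual_seed(0)
    pages = torch.randn(4, 4, 2, 8)   # 4 physical pages, 4 slots each
    page_table = [2, 0, 3, 1]         # logical page -> physical page
    # Accepted drafts sit at logical slots 5 and 10 (phys pages 0 and 3).
    expected = [pages[0, 1].clone(), pages[3, 2].clone()]
    relocate_2d_reference(pages, page_table, page_size=4, seq_len=3,
                          accepted=[5, 10])
    assert torch.equal(pages[2, 3], expected[0])  # now at logical slot 3
    assert torch.equal(pages[0, 0], expected[1])  # now at logical slot 4
```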
@sunnyqgg
Collaborator Author

sunnyqgg commented Apr 1, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41172 [ run ] triggered by Bot. Commit: 6891eae Link to invocation
