
Conversation

yeyu-nvidia
Contributor

@yeyu-nvidia yeyu-nvidia commented Oct 13, 2025

What does this PR do?

Type of change: New feature

Overview:
This PR implements a KV cache in EAGLE training, which significantly reduces memory consumption and therefore enables longer sequence lengths.
This PR also contains a prototype of a parallel-draft implementation; however, its training does not converge, so it is left as future work.

Usage

# Add a code snippet demonstrating how to use this
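
A hypothetical example of enabling the feature (not taken from this PR; the config entry point and the parallel_draft_step key are assumptions based on the existing modelopt.torch.speculative API):

import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG

# `model` is a pre-built Megatron GPT model. Start from a default EAGLE config;
# the KV cache added by this PR is used automatically during EAGLE training,
# while parallel draft is opt-in.
config = EAGLE3_DEFAULT_CFG["config"]
config["eagle_architecture_config"]["parallel_draft_step"] = 2  # illustrative key

model = mtsp.convert(model, [("eagle", config)])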

Testing

Passed regression test.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Parallel-draft speculative decoding with configurable multi-step processing and learnable draft embeddings/hidden states.
  • Refactor

    • Forwarding APIs now accept an optional inference context to propagate cache/offset info; per-step attention masking and sequence-offset tracking unified across base and draft flows.
  • Documentation

    • Notes clarifying kv-cache and sequence-parallel limitations when using parallel drafting.
  • Chores

    • Distributed data-consistency now warns and auto-aligns ranks by default (non-fatal).

…ach ttt step, only the non_parallel tokens from previous ttt are used as context

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner October 13, 2025 17:50
@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu October 13, 2025 17:50

coderabbitai bot commented Oct 13, 2025

Walkthrough

Threads a StaticInferenceContext through Megatron EAGLE forward paths, refactors per-step attention-mask construction, adds conditional parallel-draft embeddings/hidden-states and wiring, updates forward signatures and inference flow (including pseudo_speculative_generate), and relaxes cross-rank consistency checks to warn-and-coerce.
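
A minimal sketch of that threading (the method name comes from this summary; the two-argument StaticInferenceContext constructor, the return order, and the attribute names are assumptions):

from megatron.core.inference.contexts import StaticInferenceContext

# capacity follows the buffer-size rule discussed in the review comments below
eagle_ctx = StaticInferenceContext(
    batch_size, seq_len * (parallel_draft_step + ttt_steps - 1)
)

eagle_logits, eagle_hidden_states, eagle_hidden_states_pre_norm = self._eagle_forward(
    eagle_inputs, output_weight, inference_context=eagle_ctx
)
eagle_ctx.sequence_len_offset += seq_len  # advance the kv-cache offset after each pass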

Changes

Cohort / File(s) Change summary
Megatron EAGLE speculative execution
modelopt/torch/speculative/plugins/megatron_eagle.py
Propagates StaticInferenceContext through Eagle and Megatron forward paths; reworks multi-step attention-mask construction into per-step segments; adds conditional parallel_draft_embeddings and parallel_draft_hidden_states and integrates them into embedding/hidden-state flows; updates _get_eagle_module_inputs, _eagle_forward, and forward signatures to accept ttt_step, parallel_draft_index, and inference_context; _eagle_forward now returns extra pre-norm hidden states and updates a kv-cache-like eagle_inference_context; refactors pseudo_speculative_generate to pad/replace with parallel-draft tensors and documents kv-cache/sequence-parallel limitations.
Distributed utils consistency behavior
modelopt/torch/speculative/utils.py
Changes default of check_data_consistency_across_ranks to fail_when_mismatch=False (warn-and-coerce by default), updates docstring and call sites to rely on non-fatal divergence handling instead of raising on mismatch.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller as MegatronEagle.forward
  participant Emb as Embeddings
  participant Eagle as EagleModule
  participant Ctx as StaticInferenceContext
  participant Mask as Mask/Cache Builder

  Caller->>Mask: build per-step attention masks, pos_ids, rotary
  alt parallel_draft_step > 1
    Caller->>Emb: use base + parallel_draft_embeddings / parallel_draft_hidden_states
  else
    Caller->>Emb: compute standard embeddings / hidden_states
  end
  Caller->>Ctx: init eagle_inference_context (capacity, offsets)
  loop for each ttt_step
    Caller->>Eagle: forward(inputs..., ttt_step, parallel_draft_index, inference_context=Ctx)
    Eagle-->>Caller: logits, hidden_states, pre-norm states
    Caller->>Ctx: update sequence_len_offset (kv-cache-like)
    Caller->>Mask: adjust caches/masks for next step
  end
  Caller-->>Caller: assemble drafts / return outputs
  note right of Caller: pseudo_speculative_generate pads and replaces drafts when parallel drafting enabled

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A whisk of bytes, I hop through beams,
Stitching steps and eagle dreams. 🦅
Drafts in parallel, masks aligned,
Context counts each token’s time.
Ranks now warn, not crash the show—
Carrots cached, and off I go! 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
Check name Status Explanation Resolution
Title Check ❓ Inconclusive The title "Yeyu/debug paralllel draft" appears to be a branch name used as the PR title rather than a polished summary. While it is related to a real part of the changeset (parallel draft implementation is mentioned in the changes), it lacks clarity in several ways. The term "debug" is vague and non-descriptive, and critically, the title does not capture the primary feature described in the PR objectives: the KV cache implementation in Eagle training. Additionally, the title contains a typo ("paralllel" with three 'l's), suggesting it may not have been carefully considered as a proper PR title. A developer scanning the repository history would not understand from this title that the PR implements significant memory optimization features. Consider updating the title to more clearly reflect the main objectives of the PR. A better title might be something like "Implement KV cache in Eagle training with parallel draft prototype" or "Add KV cache support and parallel draft implementation for Eagle module," which would accurately convey both the primary feature (KV cache) and the secondary addition (parallel draft prototype). This would make the PR's purpose immediately clear to reviewers and future developers browsing the repository history.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/utils.py (1)

293-317: Enforce strict data‐consistency failure at call sites
All four calls (lines 336, 351, 356, 358 in modelopt/torch/speculative/utils.py) now use the default fail_when_mismatch=False, emitting only warnings on mismatch. Explicitly pass fail_when_mismatch=True at each call if divergence should remain fatal.
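
A minimal sketch of the explicit opt-in at one call site (only fail_when_mismatch is the point here; the other argument names are placeholders, not the function's actual signature):

check_data_consistency_across_ranks(
    data, group=group, fail_when_mismatch=True  # keep divergence fatal at this site
)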

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

484-490: Consider initialization strategy for learnable draft parameters.

The parallel draft embeddings and hidden states are initialized with torch.rand, which produces values in [0, 1). For better training stability, consider:

  • Zero initialization: torch.zeros
  • Small random initialization: torch.randn(...) * 0.01
  • Xavier/Kaiming initialization based on the hidden size

Random uniform initialization in [0, 1) can lead to large gradients early in training.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 46a9e49 and bfdd637.

📒 Files selected for processing (2)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
  • modelopt/torch/speculative/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (2)
modelopt/torch/speculative/utils.py (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • pseudo_speculative_generate (1332-1471)
modelopt/torch/speculative/plugins/transformers.py (1)
  • pseudo_speculative_generate (861-942)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (501-548)
  • _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
  • get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (15)
modelopt/torch/speculative/utils.py (1)

336-363: Verify behavior change at lines 336 and 358.

With the new default fail_when_mismatch=False, the calls to check_data_consistency_across_ranks at lines 336 and 358 will now warn and coerce data instead of raising an error on mismatch. This is an implicit behavior change that may not be obvious to maintainers.

If these call sites require strict consistency (e.g., for ground truth validation), explicitly pass fail_when_mismatch=True.

modelopt/torch/speculative/plugins/megatron_eagle.py (14)

29-29: LGTM!

The StaticInferenceContext import is correctly added to support KV cache functionality in EAGLE forward passes.


142-144: LGTM!

The docstring clarification improves accuracy.


190-277: LGTM: Dynamic multi-step mask construction.

The refactored implementation correctly generalizes the attention mask construction to handle an arbitrary number of steps, which is essential for the parallel draft feature. The loop-based approach (lines 267-276) builds mask segments dynamically and concatenates them.

The detailed docstring examples help understand the complex masking pattern for ttt_steps=2 and parallel_draft_step=3.


545-546: LGTM!

The inference_context parameter is correctly threaded through the Eagle module to enable KV cache support.

Also applies to: 586-586


788-839: LGTM: Multi-step draft input preparation.

The refactored logic correctly handles both regular (parallel_draft_index=0) and parallel draft (parallel_draft_index>0) scenarios:

  • For regular draft: uses actual token embeddings from input_ids
  • For parallel draft: uses learnable parameters that were initialized in EagleModule.init

The attention mask replication (lines 830-832) and rotary embedding concatenation (lines 834-837) align with the multi-step processing pattern.
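
A torch-only toy of that selection (all names and shapes are stand-ins, and the broadcast of the learnable parameter over the full sequence is one plausible reading of this summary, not the PR's exact tensors):

import torch

s, b, h, parallel_draft_step = 6, 2, 8, 3
embedding = torch.nn.Embedding(100, h)
input_ids = torch.randint(0, 100, (b, s))
parallel_draft_embeddings = torch.nn.Parameter(torch.rand(parallel_draft_step - 1, h))

def step_embeddings(parallel_draft_index: int) -> torch.Tensor:
    if parallel_draft_index == 0:
        # regular draft: embed the actual tokens, [s, b, h]
        return embedding(input_ids).transpose(0, 1)
    # parallel draft: broadcast the learnable per-slot parameter, [s, b, h]
    return (
        parallel_draft_embeddings[parallel_draft_index - 1]
        .view(1, 1, h)
        .expand(s, b, h)
    )

print(step_embeddings(0).shape, step_embeddings(2).shape)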


931-958: LGTM: KV cache offset management and return value extension.

The changes correctly:

  1. Thread inference_context through to the Eagle module (line 941)
  2. Update sequence_len_offset after each Eagle forward pass (lines 946-947) to track KV cache consumption
  3. Return pre-final-layernorm hidden states (line 957) for use as input to subsequent iterations

The sequence_len_offset increment by input_ids.shape[1] correctly accounts for the tokens processed in this Eagle forward pass.

Based on learnings.


1014-1017: LGTM: Correct KV cache capacity calculation.

The StaticInferenceContext capacity is correctly calculated as input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1). This accounts for the KV entries retained across ttt_step iterations, where each iteration discards (parallel_draft_step - 1) entries by decrementing sequence_len_offset.

Based on learnings.
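
For concreteness, a small worked example of the formula (the values are illustrative):

seq_len, parallel_draft_step, ttt_steps = 2048, 3, 2
capacity = seq_len * (parallel_draft_step + ttt_steps - 1)  # 2048 * 4 = 8192
# per ttt_step: 3 * 2048 entries are written, then 2 * 2048 are discarded by
# decrementing sequence_len_offset; the peak offset in the last ttt_step
# (2048 retained + 3 * 2048 written) reaches exactly the 8192 capacity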


1046-1063: LGTM: Label padding for offline training.

The label padding logic (lines 1046-1057) correctly handles cases where labels are one token shorter than input_ids in offline training scenarios. The comment on lines 1050-1052 appropriately notes the small error introduced for the last token when not using logit distillation.


1065-1121: LGTM: Nested loop structure with correct KV cache management.

The nested loop structure correctly implements the parallel draft algorithm:

  1. Outer loop (line 1066): Iterates over ttt_steps
  2. Inner loop (line 1068): Iterates over parallel_draft_step
  3. KV cache offset adjustment (lines 1119-1121): Correctly decrements by input_ids.shape[1] * (parallel_draft_step - 1) after each ttt_step

Line 1088 correctly overwrites next_eagle_hidden_states_pre_norm on each inner iteration (only storing the i==0 result), which aligns with the expected behavior noted in the learnings.

Based on learnings.


1097-1114: Verify the hidden state padding logic.

The hidden state manipulation (lines 1097-1114) prepends a zero tensor and shifts the hidden states. Ensure this logic correctly aligns with the attention mask and position IDs for the next ttt_step iteration.

Consider adding a comment explaining why this padding/shifting is necessary for the next iteration.


1128-1157: LGTM: Loss computation and accuracy tracking.

The loss computation correctly:

  1. Slices logits into chunks of input_ids.shape[1] (lines 1129-1130)
  2. Applies a decay factor based on ttt_step and parallel_draft_index (lines 1132-1133)
  3. Accumulates weighted losses (lines 1132-1134)

The accuracy tracking (lines 1136-1150) properly handles vocab size adjustments and reports only on rank 0 to avoid duplicate output.
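
A self-contained toy of the decay-weighted accumulation described above, including the offset-by-one alignment discussed later in this review (the exponent, shapes, and chunking are illustrative, not copied from the diff):

import torch

b, s, ttt_steps, parallel_draft_step = 2, 8, 2, 3
eagle_loss_decay_factor = 0.9
# stand-in per-step losses with the [b, s-1] shape returned by _compute_eagle_loss
per_step_loss = torch.rand(ttt_steps * parallel_draft_step, b, s - 1)

loss = torch.zeros(b, s)
for ttt_step in range(ttt_steps):
    for i in range(parallel_draft_step):
        loss_ = per_step_loss[ttt_step * parallel_draft_step + i]
        weight = eagle_loss_decay_factor ** (ttt_step + i)  # illustrative exponent
        loss[:, i + ttt_step + 1 :] += weight * loss_[:, i + ttt_step :]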


1339-1340: LGTM: Accurate docstring update.

The docstring correctly notes that KV cache is not supported in this function when sequence parallel is enabled. This is an important limitation for users to be aware of.


1389-1423: LGTM: Parallel draft embedding replacement.

The implementation correctly:

  1. Pads dummy tokens and hidden states (lines 1389-1393)
  2. Gathers embeddings when sequence parallel is used (lines 1407-1410)
  3. Replaces dummy entries with learnable parallel_draft_embeddings and parallel_draft_hidden_states (lines 1414-1419)
  4. Scatters back if sequence parallel (lines 1420-1423)

The conditional handling of sequence_parallel ensures correctness in both modes.
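
A torch-only toy of the pad-then-replace step (the sequence-parallel gather/scatter around it is omitted; shapes and names are stand-ins):

import torch

s, b, h, parallel_draft_step = 6, 2, 8, 3
emb = torch.rand(s, b, h)                              # embeddings including dummy pads
parallel_draft_embeddings = torch.rand(parallel_draft_step - 1, h)

# the last (parallel_draft_step - 1) positions were dummy padding; overwrite
# them with the learnable parameters, broadcast over the batch dimension
emb[-(parallel_draft_step - 1):] = parallel_draft_embeddings.unsqueeze(1).expand(-1, b, -1)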


1442-1467: LGTM: Multi-token draft extraction and cleanup.

The draft token extraction correctly:

  1. Extracts the last parallel_draft_step tokens from logits (lines 1443-1445)
  2. Applies vocab mapping if needed (lines 1449-1450)
  3. Removes dummy tokens before appending draft tokens (lines 1456-1458)
  4. Appends the correct number of hidden states for the next iteration (lines 1461-1467)

The cleanup logic ensures that dummy padding is removed before progressing to the next draft step.
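
A torch-only toy of the extraction and cleanup listed above (vocab mapping and hidden-state bookkeeping are omitted; shapes are stand-ins):

import torch

s, b, vocab, parallel_draft_step = 6, 2, 50, 3
logits = torch.rand(s, b, vocab)                       # [s, b, vocab]
input_ids = torch.randint(0, vocab, (b, s))            # includes the dummy pads

draft_tokens = logits[-parallel_draft_step:].argmax(dim=-1).transpose(0, 1)  # [b, k]
input_ids = input_ids[:, : -(parallel_draft_step - 1)]    # drop the dummy pads
input_ids = torch.cat((input_ids, draft_tokens), dim=-1)  # append the drafted tokens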


codecov bot commented Oct 13, 2025

Codecov Report

❌ Patch coverage is 33.33333% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.37%. Comparing base (46a9e49) to head (a663526).
⚠️ Report is 9 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/utils.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #429      +/-   ##
==========================================
- Coverage   73.38%   73.37%   -0.01%     
==========================================
  Files         180      180              
  Lines       17934    17937       +3     
==========================================
+ Hits        13160    13162       +2     
- Misses       4774     4775       +1     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

197-284: NO dedicated unit tests found for set_multi_step_attention_mask—add tests to verify mask construction correctness.

The refactored function implements complex multi-step attention mask logic for TTT steps and parallel draft indices. While integration tests exist for the eagle module (in test_speculative_megatron_modules.py), they do not directly validate the correctness of this mask construction. The function is called in production code at lines 837-838 and is critical for correct attention computation.

Recommended actions:

  • Add unit tests for set_multi_step_attention_mask with specific test cases (a sketch follows this list):
    • Verify output mask shape matches expected dimensions
    • Validate mask values (True/False) correspond to documented attention patterns in the docstring
    • Test edge cases: step=0, step=1, various batch and sequence lengths
    • Cross-check mask construction against the ASCII diagrams in the docstring for ttt_step=0 and ttt_step=1
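
The kind of unit test suggested above might look like this sketch (the function's exact signature, return value, and mask convention are assumptions; expected values would come from the docstring diagrams):

import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

def test_multi_step_attention_mask_shape():
    b, s = 2, 8
    # causal mask placeholder in the True-means-masked convention (assumed)
    base = torch.triu(torch.ones(b, 1, s, s, dtype=torch.bool), diagonal=1)
    mask = set_multi_step_attention_mask(base, step=2)   # signature assumed
    assert mask.shape[-1] == mask.shape[-2]              # stays square
    assert mask.shape[-1] % s == 0                       # grows in s-sized segments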
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)

150-150: Minor: Docstring wording could be more consistent with function name.

The docstring says "greater than" but the function is named mcore_version_higher_than. Consider using "higher than" to match the function name, or this is acceptable as-is.


490-497: Use proper initialization strategy for learnable parameters.

The parallel draft embeddings and hidden states are initialized with torch.rand, which produces values uniformly distributed in [0, 1). This is not ideal for neural network parameters and could contribute to convergence issues.

Consider using proper initialization strategies like Xavier/Kaiming initialization:

-        if config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
+        if config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.zeros(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            torch.nn.init.xavier_uniform_(self.parallel_draft_embeddings)
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.zeros(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            torch.nn.init.xavier_uniform_(self.parallel_draft_hidden_states)

This could help address convergence issues mentioned in the PR description.


1396-1430: Consider preallocating parallel draft embeddings instead of padding then replacing.

The current approach pads dummy tokens and hidden states (lines 1396-1400) then replaces them with parallel draft values (lines 1419-1426). This is inefficient and adds complexity.

Consider directly constructing the tensors with the correct values, or building a helper function that handles this logic more cleanly.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfdd637 and fc4551a.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (18 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (501-548)
  • _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
  • get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)

29-29: LGTM: Clean imports with proper fallback handling.

The StaticInferenceContext import and IGNORE_TOKEN_ID fallback are well-structured.

Also applies to: 71-76


552-552: LGTM: Clean inference context threading.

The inference_context parameter is properly added to the signature and threaded through to the decoder for KV cache support.

Also applies to: 593-593


801-846: LGTM: Parallel draft input preparation aligns with design.

The refactored _get_eagle_module_inputs properly handles multi-step and parallel draft indexing. The conditional logic for using learnable parallel_draft_embeddings and parallel_draft_hidden_states (when parallel_draft_index > 0) is correct. Based on learnings.


938-938: LGTM: Proper KV cache management and return value extension.

The inference_context is correctly threaded through, and the sequence_len_offset update (lines 952-954) properly tracks the cache position after each EAGLE module call. The extended return tuple including pre-norm hidden states is necessary for chaining multiple draft steps.

Also applies to: 948-948, 952-965


1020-1024: LGTM: KV cache buffer size correctly calculated.

The eagle_inference_context buffer size calculation input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1) is correct. Based on learnings.


1072-1128: LGTM: Parallel draft loop structure matches expected behavior.

The nested loop structure over ttt_steps and parallel_draft_step correctly:

  • Calls EAGLE forward for each draft step
  • Retains only the first draft's pre-norm states for the next iteration (line 1095)
  • Updates hidden states with proper sequence parallel handling (lines 1100-1121)
  • Adjusts KV cache offset by discarding the last (parallel_draft_step - 1) entries (lines 1123-1128)

This aligns with the design documented in learnings.


1449-1474: LGTM: Draft token generation and cleanup logic is correct.

The draft token extraction from the last parallel_draft_step positions (lines 1449-1455) and the subsequent cleanup of dummy tokens (lines 1461-1465) correctly maintain the sequence for iterative drafting.


1346-1347: The review comment is incorrect and should be disregarded.

The code is internally consistent. The _eagle_forward method's inference_context parameter defaults to None. The main forward method unconditionally creates and passes eagle_inference_context, while pseudo_speculative_generate does not pass this parameter, so it defaults to None. These represent intentionally different implementations, not an inconsistency:

  • pseudo_speculative_generate does not create or pass inference_context → no KV cache support
  • Main forward creates and passes inference_context → KV cache is used

The comment at line 1346-1347 accurately documents that pseudo_speculative_generate does not support KV cache, which is correct given it doesn't invoke the KV cache infrastructure at all.

Likely an incorrect or invalid review comment.


1135-1141: Loss decay indexing is correct and intentional.

The offset-by-one pattern aligns EAGLE predictions properly with target tokens. The _compute_eagle_loss function returns shape [b, s-1] because it compares eagle_logits[:-1] with logits[1:] (next-token shift built-in). Slicing loss_[:, i + ttt_step:] from this tensor and assigning to loss[:, i + ttt_step + 1:] maintains correct alignment. The accuracy reporting code (line 1153) uses the identical offset pattern, confirming this is by design. Test coverage validates output shape [batch_size, max_sequence_length] is correct.

Comment on lines +1053 to +1070
if labels is not None:
    if labels.shape[1] == input_ids.shape[1] - 1:
        # For offline training, labels may be 1 token shorter than input_ids.
        # We will just pad a 0 to the labels to make the seq_len the same as
        # input_ids. This will introduce a small error in training if logit_distillation
        # is False, and testing accuracy is wrong for the last token.
        right_token_pad = torch.zeros(
            (labels.shape[0], 1),
            dtype=labels.dtype,
            device=labels.device,
        )
        labels = torch.cat((labels, right_token_pad), dim=-1)

    # If eagle_freeze_base_model is set to True,
    # the base model is frozen .
    loss = self.compute_language_model_loss(labels, logits_sbh)
    if self.eagle_freeze_base_model:
        loss = 0.0 * loss

⚠️ Potential issue | 🟠 Major

Label padding workaround introduces training errors.

Lines 1054-1064 pad labels with zeros when they're 1 token shorter than input_ids. The comment acknowledges this "will introduce a small error in training" and that "testing accuracy is wrong for the last token."

Consider tracking this as a known limitation in the PR description or documentation, and prioritize a proper fix.
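
One possible direction for that fix is to pad with the loss-ignore index rather than 0, assuming the loss path honors IGNORE_TOKEN_ID; a sketch only, not the PR's code:

import torch

IGNORE_TOKEN_ID = -100                                  # HF-style ignore index (assumed)
labels = torch.randint(0, 100, (2, 7))                  # stand-in labels, 1 token short
pad = torch.full((labels.shape[0], 1), IGNORE_TOKEN_ID, dtype=labels.dtype)
labels = torch.cat((labels, pad), dim=-1)               # padded position excluded from loss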


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

1046-1063: Label padding workaround persists from previous review.

This is the same label padding issue flagged in the previous review. The workaround pads labels with zeros when they're 1 token shorter than input_ids, which "will introduce a small error in training" per the inline comment. Since this is marked as future work and the PR is a draft prototype, documenting this as a known limitation is acceptable for now.

🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)

143-143: Minor: Docstring wording could be clearer.

The docstring says "greater than this version" but could be more precise: "Check if megatron-core version is strictly greater than the target version."


484-490: Consider better initialization for learnable draft parameters.

The parallel draft embeddings and hidden states are initialized with torch.rand(), which produces uniform [0, 1) values. This may not be optimal for neural network parameters. Consider using Xavier/Kaiming initialization or normal distribution with appropriate scaling.

Apply this diff to use Xavier uniform initialization:

         # Set up learnable parallel draft embeddings and hidden_states
         if config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
+            parallel_draft_embeddings = torch.empty(config.parallel_draft_step - 1, config.hidden_size)
+            torch.nn.init.xavier_uniform_(parallel_draft_embeddings)
+            self.parallel_draft_embeddings = torch.nn.Parameter(parallel_draft_embeddings)
+            
+            parallel_draft_hidden_states = torch.empty(config.parallel_draft_step - 1, config.hidden_size)
+            torch.nn.init.xavier_uniform_(parallel_draft_hidden_states)
+            self.parallel_draft_hidden_states = torch.nn.Parameter(parallel_draft_hidden_states)

1339-1340: Document KV cache limitation with sequence parallelism.

The comment notes that KV cache is not supported with sequence parallel, but doesn't explain why. Consider documenting the technical reason (e.g., cache partitioning conflicts, synchronization issues) to help future maintainers understand this constraint.


1389-1467: Dummy token scaffolding adds complexity.

The approach of:

  1. Padding dummy tokens and hidden states (lines 1389-1393)
  2. Replacing them with parallel_draft parameters (lines 1414-1419)
  3. Removing dummy tokens before adding real tokens (lines 1456-1458)

...introduces significant complexity. While this may be necessary given the current architecture, consider whether a more direct approach is feasible in future iterations. The current implementation works but is harder to understand and maintain.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fc4551a and a663526.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (501-548)
  • _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
  • get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)

29-29: LGTM: Import for KV cache implementation.

The StaticInferenceContext import is necessary for the KV cache functionality being added.


190-277: LGTM: Multi-step attention mask refactored to support parallel draft.

The refactoring from fixed nested blocks to a per-step loop with dynamic mask construction properly supports the new ttt_step and parallel_draft_step semantics. The extensive documentation with examples is helpful for understanding the complex masking logic.


788-838: LGTM: Refactored to support parallel draft with learned embeddings.

The function now properly handles parallel draft by:

  • Using real embeddings for the first draft (parallel_draft_index == 0)
  • Using learnable parallel_draft_embeddings and parallel_draft_hidden_states for subsequent drafts
  • Adjusting attention masks and rotary embeddings accordingly

This aligns with the retrieved learnings about parallel draft behavior.

Based on learnings


925-958: LGTM: Proper inference context threading and cache offset management.

The inference_context is correctly:

  1. Added as an optional parameter
  2. Propagated to eagle_module
  3. Updated with sequence_len_offset after each call for KV cache management

1013-1017: LGTM: KV cache buffer sized correctly.

The StaticInferenceContext buffer size calculation input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1) is correct based on the cache management strategy where each ttt_step discards (parallel_draft_step - 1) worth of KV entries.

Based on learnings


1065-1121: LGTM: Training loop refactored for multi-step parallel draft.

The nested loop structure properly implements:

  • Outer loop over ttt_steps for temporal processing
  • Inner loop over parallel_draft_step for parallel token generation
  • Correct KV cache offset management (lines 1119-1121) that discards cache for the last (parallel_draft_step - 1) tokens

The complexity is inherent to the parallel draft algorithm.

Based on learnings


1136-1157: LGTM: Accuracy reporting properly implemented.

The accuracy reporting:

  • Correctly gates on eagle_report_acc and inference mode
  • Handles draft vocab size mapping
  • Only prints from TP rank 0 to avoid duplicates
  • Calculates top-1 accuracy for each parallel draft position

1407-1423: LGTM: Sequence parallel handling is correct but complex.

The gather/scatter pattern for sequence parallelism with the dummy token replacement is correctly implemented. The complexity is inherent to supporting both sequence parallelism and parallel draft functionality.


545-545: No action needed - code is compatible with project requirements.

The project's setup.py specifies python_requires=">=3.10,<3.13", which fully supports PEP 604 union syntax. Additionally, the codebase already extensively uses this syntax throughout test files and examples, confirming it is an established pattern. The inference_context: StaticInferenceContext | None annotation aligns with both the Python version requirement and existing code conventions.
