Yeyu/debug paralllel draft #429
base: main
Conversation
…ach ttt step, only the non_parallel tokens from previous ttt are used as context
…tead
…r base model param freeze
Walkthrough
Threads a StaticInferenceContext through Megatron EAGLE forward paths, refactors per-step attention-mask construction, adds conditional parallel-draft embeddings/hidden-states and wiring, updates forward signatures and inference flow (including pseudo_speculative_generate), and relaxes cross-rank consistency checks to warn-and-coerce.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller as MegatronEagle.forward
participant Emb as Embeddings
participant Eagle as EagleModule
participant Ctx as StaticInferenceContext
participant Mask as Mask/Cache Builder
Caller->>Mask: build per-step attention masks, pos_ids, rotary
alt parallel_draft_step > 1
Caller->>Emb: use base + parallel_draft_embeddings / parallel_draft_hidden_states
else
Caller->>Emb: compute standard embeddings / hidden_states
end
Caller->>Ctx: init eagle_inference_context (capacity, offsets)
loop for each ttt_step
Caller->>Eagle: forward(inputs..., ttt_step, parallel_draft_index, inference_context=Ctx)
Eagle-->>Caller: logits, hidden_states, pre-norm states
Caller->>Ctx: update sequence_len_offset (kv-cache-like)
Caller->>Mask: adjust caches/masks for next step
end
Caller-->>Caller: assemble drafts / return outputs
note right of Caller: pseudo_speculative_generate pads and replaces drafts when parallel drafting enabled
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed (inconclusive), ✅ 2 passed
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/utils.py (1)
293-317: Enforce strict data-consistency failure at call sites.
All four calls (lines 336, 351, 356, 358 in modelopt/torch/speculative/utils.py) now use the default fail_when_mismatch=False, emitting only warnings on mismatch. Explicitly pass fail_when_mismatch=True at each call if divergence should remain fatal, as in the sketch below.
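A minimal sketch of restoring the strict behavior at a call site. The positional argument is an assumption for illustration; only the fail_when_mismatch keyword comes from the review above, so the real signature should be checked before use.

```python
# Hypothetical call-site sketch, not the library's verified signature.
from modelopt.torch.speculative.utils import check_data_consistency_across_ranks

def strict_label_check(labels):
    # Fail fast on cross-rank divergence instead of warn-and-coerce.
    check_data_consistency_across_ranks(labels, fail_when_mismatch=True)
```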
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
484-490: Consider the initialization strategy for learnable draft parameters.
The parallel draft embeddings and hidden states are initialized with torch.rand, which produces values in [0, 1). For better training stability, consider:
- Zero initialization: torch.zeros
- Small random initialization: torch.randn(...) * 0.01
- Xavier/Kaiming initialization based on the hidden size
Random uniform initialization in [0, 1) can lead to large gradients early in training. A sketch of these alternatives follows the list below.
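A minimal sketch of the three suggested alternatives. The (parallel_draft_step - 1, hidden_size) shape follows the review; the concrete values and variable names are placeholders, not the module's actual __init__ code.

```python
import torch

# Illustrative values only; not the model's real config.
parallel_draft_step, hidden_size = 3, 4096
num_extra = parallel_draft_step - 1

# Option 1: zero initialization
zeros_emb = torch.nn.Parameter(torch.zeros(num_extra, hidden_size))

# Option 2: small random initialization
small_randn_emb = torch.nn.Parameter(torch.randn(num_extra, hidden_size) * 0.01)

# Option 3: Xavier initialization scaled by the hidden size
xavier_emb = torch.nn.Parameter(torch.empty(num_extra, hidden_size))
torch.nn.init.xavier_uniform_(xavier_emb)
```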
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
modelopt/torch/speculative/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (2)
modelopt/torch/speculative/utils.py (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
- pseudo_speculative_generate (1332-1471)
modelopt/torch/speculative/plugins/transformers.py (1)
- pseudo_speculative_generate (861-942)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
- _get_eagle_module_inputs (501-548)
- _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
- get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (15)
modelopt/torch/speculative/utils.py (1)
336-363: Verify the behavior change at lines 336 and 358.
With the new default fail_when_mismatch=False, the calls to check_data_consistency_across_ranks at lines 336 and 358 will now warn and coerce data instead of raising an error on mismatch. This is an implicit behavior change that may not be obvious to maintainers. If these call sites require strict consistency (e.g., for ground-truth validation), explicitly pass fail_when_mismatch=True.
modelopt/torch/speculative/plugins/megatron_eagle.py (14)
29-29: LGTM! The StaticInferenceContext import is correctly added to support KV cache functionality in EAGLE forward passes.
142-144: LGTM! The docstring clarification improves accuracy.
190-277: LGTM: Dynamic multi-step mask construction.
The refactored implementation correctly generalizes the attention mask construction to handle an arbitrary number of steps, which is essential for the parallel draft feature. The loop-based approach (lines 267-276) builds mask segments dynamically and concatenates them.
The detailed docstring examples help in understanding the complex masking pattern for ttt_steps=2 and parallel_draft_step=3. A generic sketch of the segment-and-concatenate pattern is shown below.
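This is not the actual set_multi_step_attention_mask implementation; it is a self-contained toy illustration of building per-step boolean mask segments and concatenating them, with the per-step shift chosen arbitrarily.

```python
import torch

def toy_multi_step_mask(seq_len: int, num_steps: int) -> torch.Tensor:
    """Toy illustration: a causal segment for the base sequence plus one
    progressively shifted segment per extra step, concatenated along the key
    dimension. True = masked out. Mirrors the 'build segments, then cat'
    structure of the refactor, not its exact attention pattern."""
    # Base causal segment: query i may attend to keys 0..i.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1)
    segments = [causal]
    for step in range(1, num_steps):
        # Each later segment exposes keys shifted one position further back.
        segments.append(
            torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1 - step)
        )
    return torch.cat(segments, dim=-1)  # [seq_len, num_steps * seq_len]

mask = toy_multi_step_mask(seq_len=4, num_steps=3)
print(mask.shape)  # torch.Size([4, 12])
```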
545-546: LGTM! The inference_context parameter is correctly threaded through the Eagle module to enable KV cache support.
Also applies to: 586-586
788-839: LGTM: Multi-step draft input preparation.
The refactored logic correctly handles both the regular (parallel_draft_index=0) and parallel draft (parallel_draft_index>0) scenarios:
- For the regular draft: uses actual token embeddings from input_ids
- For parallel drafts: uses learnable parameters initialized in EagleModule.__init__
The attention mask replication (lines 830-832) and rotary embedding concatenation (lines 834-837) align with the multi-step processing pattern.
931-958: LGTM: KV cache offset management and return value extension.
The changes correctly:
- Thread inference_context through to the Eagle module (line 941)
- Update sequence_len_offset after each Eagle forward pass (lines 946-947) to track KV cache consumption
- Return pre-final-layernorm hidden states (line 957) for use as input to subsequent iterations
The sequence_len_offset increment by input_ids.shape[1] correctly accounts for the tokens processed in this Eagle forward pass. Based on learnings.
1014-1017: LGTM: Correct KV cache capacity calculation.
The StaticInferenceContext capacity is correctly calculated as input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1). This accounts for the KV entries retained across ttt_step iterations, where each iteration discards (parallel_draft_step - 1) entries' worth by decrementing sequence_len_offset. Based on learnings. A small worked example of this sizing is sketched below.
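A tiny arithmetic sketch of the capacity formula; the sequence length, parallel_draft_step, and ttt_steps values are made up for illustration.

```python
# Assumed values for illustration only.
seq_len = 512                 # input_ids.shape[1]
parallel_draft_step = 3
ttt_steps = 2

# Capacity from the formula in the review/learnings:
capacity = seq_len * (parallel_draft_step + ttt_steps - 1)   # 512 * 4 = 2048

# Within a ttt_step the offset advances by seq_len for each of the
# parallel_draft_step calls, then rolls back (parallel_draft_step - 1) of them,
# so only one seq_len's worth of KV is retained per completed ttt_step.
retained_per_step = seq_len
peak_within_step = seq_len * parallel_draft_step
peak_overall = retained_per_step * (ttt_steps - 1) + peak_within_step  # 512 + 1536
assert peak_overall == capacity
```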
1046-1063: LGTM: Label padding for offline training.
The label padding logic (lines 1046-1057) correctly handles cases where labels are one token shorter than input_ids in offline training scenarios. The comment on lines 1050-1052 appropriately notes the small error introduced for the last token when not using logit distillation.
1065-1121: LGTM: Nested loop structure with correct KV cache management.
The nested loop structure correctly implements the parallel draft algorithm:
- Outer loop (line 1066): iterates over ttt_steps
- Inner loop (line 1068): iterates over parallel_draft_step
- KV cache offset adjustment (lines 1119-1121): correctly decrements by input_ids.shape[1] * (parallel_draft_step - 1) after each ttt_step
Line 1088 correctly overwrites next_eagle_hidden_states_pre_norm on each inner iteration (storing only the i==0 result), which aligns with the expected behavior noted in the learnings. Based on learnings.
1097-1114: Verify the hidden state padding logic.
The hidden state manipulation (lines 1097-1114) prepends a zero tensor and shifts the hidden states. Ensure this logic correctly aligns with the attention mask and position IDs for the next ttt_step iteration; a generic shift-and-pad sketch is shown below.
Consider adding a comment explaining why this padding/shifting is necessary for the next iteration.
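Not the PR's actual code; a minimal illustration of the prepend-zero-then-shift idiom for a [seq, batch, hidden] tensor, so the alignment discussion above is concrete. Shapes are arbitrary.

```python
import torch

# Illustrative shapes only.
seq_len, batch, hidden = 5, 2, 8
hidden_states = torch.randn(seq_len, batch, hidden)

# Prepend a zero "token" along the sequence dimension and drop the last
# position, shifting every hidden state one step later. The next iteration
# then sees position t holding the state that previously sat at t-1.
zero_pad = torch.zeros(1, batch, hidden, dtype=hidden_states.dtype)
shifted = torch.cat((zero_pad, hidden_states[:-1]), dim=0)

assert shifted.shape == hidden_states.shape
assert torch.equal(shifted[1:], hidden_states[:-1])
```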
1128-1157: LGTM: Loss computation and accuracy tracking.
The loss computation correctly:
- Slices logits into chunks of input_ids.shape[1] (lines 1129-1130)
- Applies a decay factor based on ttt_step and parallel_draft_index (lines 1132-1133)
- Accumulates weighted losses (lines 1132-1134)
The accuracy tracking (lines 1136-1150) properly handles vocab size adjustments and reports only on rank 0 to avoid duplicate output.
1339-1340: LGTM: Accurate docstring update.
The docstring correctly notes that KV cache is not supported in this function when sequence parallel is enabled. This is an important limitation for users to be aware of.
1389-1423: LGTM: Parallel draft embedding replacement.
The implementation correctly:
- Pads dummy tokens and hidden states (lines 1389-1393)
- Gathers embeddings when sequence parallel is used (lines 1407-1410)
- Replaces dummy entries with learnable parallel_draft_embeddings and parallel_draft_hidden_states (lines 1414-1419)
- Scatters back if sequence parallel (lines 1420-1423)
The conditional handling of sequence_parallel ensures correctness in both modes.
1442-1467: LGTM: Multi-token draft extraction and cleanup.
The draft token extraction correctly:
- Extracts the last parallel_draft_step tokens from logits (lines 1443-1445)
- Applies vocab mapping if needed (lines 1449-1450)
- Removes dummy tokens before appending draft tokens (lines 1456-1458)
- Appends the correct number of hidden states for the next iteration (lines 1461-1467)
The cleanup logic ensures that dummy padding is removed before progressing to the next draft step.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #429 +/- ##
==========================================
- Coverage 73.38% 73.37% -0.01%
==========================================
Files 180 180
Lines 17934 17937 +3
==========================================
+ Hits 13160 13162 +2
- Misses 4774 4775 +1
☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
197-284: No dedicated unit tests found for set_multi_step_attention_mask — add tests to verify mask construction correctness.
The refactored function implements complex multi-step attention mask logic for TTT steps and parallel draft indices. While integration tests exist for the eagle module (in test_speculative_megatron_modules.py), they do not directly validate the correctness of this mask construction. The function is called in production code at lines 837-838 and is critical for correct attention computation.
Recommended actions (a test sketch follows this list):
- Add unit tests for set_multi_step_attention_mask with specific test cases:
  - Verify output mask shape matches expected dimensions
  - Validate mask values (True/False) correspond to documented attention patterns in the docstring
  - Test edge cases: step=0, step=1, various batch and sequence lengths
  - Cross-check mask construction against the ASCII diagrams in the docstring for ttt_step=0 and ttt_step=1
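A hedged pytest sketch along the lines the review asks for. The call signature of set_multi_step_attention_mask, the base-mask layout, and the "True means masked, one extra seq-sized block per step" expectation are all assumptions for illustration and must be replaced with the real function's contract.

```python
# Hypothetical test sketch: the signature and shape semantics below are
# assumptions, not verified against the code under review.
import pytest
import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask


@pytest.mark.parametrize("step", [0, 1, 2])
def test_mask_shape_grows_with_step(step):
    batch, seq = 2, 8
    # Assumed: a [batch, 1, seq, seq] boolean causal mask where True = masked.
    base = torch.ones(batch, 1, seq, seq, dtype=torch.bool).triu(diagonal=1)

    mask = set_multi_step_attention_mask(base, step)  # assumed signature

    # Assumed: each step appends one seq-sized block along both attention dims.
    expected = seq * (step + 1)
    assert mask.shape[-2:] == (expected, expected)
    assert mask.dtype == torch.bool
```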
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)
150-150: Minor: Docstring wording could be more consistent with the function name.
The docstring says "greater than" but the function is named mcore_version_higher_than. Consider using "higher than" to match the function name; alternatively, this is acceptable as-is.
490-497: Use a proper initialization strategy for learnable parameters.
The parallel draft embeddings and hidden states are initialized with torch.rand, which produces values uniformly distributed in [0, 1). This is not ideal for neural network parameters and could contribute to convergence issues. Consider using proper initialization strategies like Xavier/Kaiming initialization:
-        if config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-            )
+        if config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.zeros(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            torch.nn.init.xavier_uniform_(self.parallel_draft_embeddings)
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.zeros(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            torch.nn.init.xavier_uniform_(self.parallel_draft_hidden_states)
This could help address the convergence issues mentioned in the PR description.
1396-1430: Consider preallocating parallel draft embeddings instead of padding then replacing.
The current approach pads dummy tokens and hidden states (lines 1396-1400) and then replaces them with parallel draft values (lines 1419-1426). This is inefficient and adds complexity.
Consider directly constructing the tensors with the correct values, or building a helper function that handles this logic more cleanly.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (18 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
- _get_eagle_module_inputs (501-548)
- _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
- get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)
29-29: LGTM: Clean imports with proper fallback handling.
The StaticInferenceContext import and IGNORE_TOKEN_ID fallback are well-structured.
Also applies to: 71-76
552-552: LGTM: Clean inference context threading.
The inference_context parameter is properly added to the signature and threaded through to the decoder for KV cache support.
Also applies to: 593-593
801-846: LGTM: Parallel draft input preparation aligns with design.
The refactored _get_eagle_module_inputs properly handles multi-step and parallel draft indexing. The conditional logic for using the learnable parallel_draft_embeddings and parallel_draft_hidden_states (when parallel_draft_index > 0) is correct. Based on learnings.
938-938: LGTM: Proper KV cache management and return value extension.
The inference_context is correctly threaded through, and the sequence_len_offset update (lines 952-954) properly tracks the cache position after each EAGLE module call. The extended return tuple including pre-norm hidden states is necessary for chaining multiple draft steps.
Also applies to: 948-948, 952-965
1020-1024: LGTM: KV cache buffer size correctly calculated.
The eagle_inference_context buffer size calculation input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1) is correct based on the cache management strategy, where each ttt_step discards (parallel_draft_step - 1) tokens' worth of KV entries. Based on learnings.
1072-1128: LGTM: Parallel draft loop structure matches expected behavior.
The nested loop structure over ttt_steps and parallel_draft_step correctly:
- Calls EAGLE forward for each draft step
- Retains only the first draft's pre-norm states for the next iteration (line 1095)
- Updates hidden states with proper sequence parallel handling (lines 1100-1121)
- Adjusts KV cache offset by discarding the last (parallel_draft_step - 1) entries (lines 1123-1128)
This aligns with the design documented in learnings.
1449-1474: LGTM: Draft token generation and cleanup logic is correct.
The draft token extraction from the last parallel_draft_step positions (lines 1449-1455) and the subsequent cleanup of dummy tokens (lines 1461-1465) correctly maintain the sequence for iterative drafting.
1346-1347: The review comment is incorrect and should be disregarded.
The code is internally consistent. The _eagle_forward method's inference_context parameter defaults to None. The main forward method unconditionally creates and passes eagle_inference_context, while pseudo_speculative_generate does not pass this parameter, so it defaults to None. These represent intentionally different implementations, not an inconsistency:
- pseudo_speculative_generate does not create or pass inference_context → no KV cache support
- The main forward creates and passes inference_context → the KV cache is used
The comment at lines 1346-1347 accurately documents that pseudo_speculative_generate does not support KV cache, which is correct given that it doesn't invoke the KV cache infrastructure at all. Likely an incorrect or invalid review comment.
1135-1141: Loss decay indexing is correct and intentional.
The offset-by-one pattern aligns EAGLE predictions properly with target tokens. The _compute_eagle_loss function returns shape [b, s-1] because it compares eagle_logits[:-1] with logits[1:] (the next-token shift is built in). Slicing loss_[:, i + ttt_step:] from this tensor and assigning it to loss[:, i + ttt_step + 1:] maintains correct alignment; a small slicing sketch follows. The accuracy reporting code (line 1153) uses the identical offset pattern, confirming this is by design. Test coverage validates that the output shape [batch_size, max_sequence_length] is correct.
if labels is not None:
    if labels.shape[1] == input_ids.shape[1] - 1:
        # For offline training, labels may be 1 token shorter than input_ids.
        # We will just pad a 0 to the labels to make the seq_len the same as
        # input_ids. This will introduce a small error in training if logit_distillation
        # is False, and testing accuracy is wrong for the last token.
        right_token_pad = torch.zeros(
            (labels.shape[0], 1),
            dtype=labels.dtype,
            device=labels.device,
        )
        labels = torch.cat((labels, right_token_pad), dim=-1)

    # If eagle_freeze_base_model is set to True,
    # the base model is frozen.
    loss = self.compute_language_model_loss(labels, logits_sbh)
    loss = 0.0 * loss

    # If eagle_freeze_base_model is set to True,
    # the base model is frozen.
    loss = self.compute_language_model_loss(labels, logits_sbh)
    if self.eagle_freeze_base_model:
        loss = 0.0 * loss
Label padding workaround introduces training errors.
Lines 1054-1064 pad labels with zeros when they're 1 token shorter than input_ids. The comment acknowledges that this "will introduce a small error in training" and that "testing accuracy is wrong for the last token." While this is marked as future work, it could accumulate errors during training.
Consider tracking this as a known limitation in the PR description or documentation, and prioritize a proper fix.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
1046-1063: Label padding workaround persists from the previous review.
This is the same label padding issue flagged in the previous review. The workaround pads labels with zeros when they're 1 token shorter than input_ids, which "will introduce a small error in training" per the inline comment. Since this is marked as future work and the PR is a draft prototype, documenting this as a known limitation is acceptable for now.
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
143-143: Minor: Docstring wording could be clearer.
The docstring says "greater than this version" but could be more precise: "Check if the megatron-core version is strictly greater than the target version."
484-490: Consider better initialization for learnable draft parameters.
The parallel draft embeddings and hidden states are initialized with torch.rand(), which produces uniform [0, 1) values. This may not be optimal for neural network parameters. Consider using Xavier/Kaiming initialization or a normal distribution with appropriate scaling. Apply this diff to use Xavier uniform initialization:
 # Set up learnable parallel draft embeddings and hidden_states
 if config.parallel_draft_step > 1:
-    self.parallel_draft_embeddings = torch.nn.Parameter(
-        torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-    )
-    self.parallel_draft_hidden_states = torch.nn.Parameter(
-        torch.rand(config.parallel_draft_step - 1, config.hidden_size)
-    )
+    parallel_draft_embeddings = torch.empty(config.parallel_draft_step - 1, config.hidden_size)
+    torch.nn.init.xavier_uniform_(parallel_draft_embeddings)
+    self.parallel_draft_embeddings = torch.nn.Parameter(parallel_draft_embeddings)
+
+    parallel_draft_hidden_states = torch.empty(config.parallel_draft_step - 1, config.hidden_size)
+    torch.nn.init.xavier_uniform_(parallel_draft_hidden_states)
+    self.parallel_draft_hidden_states = torch.nn.Parameter(parallel_draft_hidden_states)
1339-1340: Document the KV cache limitation with sequence parallelism.
The comment notes that KV cache is not supported with sequence parallel, but doesn't explain why. Consider documenting the technical reason (e.g., cache partitioning conflicts, synchronization issues) to help future maintainers understand this constraint.
1389-1467: Dummy token scaffolding adds complexity.
The approach of:
- Padding dummy tokens and hidden states (lines 1389-1393)
- Replacing them with parallel_draft parameters (lines 1414-1419)
- Removing dummy tokens before adding real tokens (lines 1456-1458)
...introduces significant complexity. While this may be necessary given the current architecture, consider whether a more direct approach is feasible in future iterations. The current implementation works but is harder to understand and maintain.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
modelopt/torch/speculative/plugins/transformers.py (2)
- _get_eagle_module_inputs (501-548)
- _eagle_forward (632-657)
modelopt/torch/speculative/utils.py (1)
- get_default_attention_mask_and_position_ids (49-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)
29-29: LGTM: Import for KV cache implementation.
The StaticInferenceContext import is necessary for the KV cache functionality being added.
190-277: LGTM: Multi-step attention mask refactored to support parallel draft.
The refactoring from fixed nested blocks to a per-step loop with dynamic mask construction properly supports the new ttt_step and parallel_draft_step semantics. The extensive documentation with examples is helpful for understanding the complex masking logic.
788-838: LGTM: Refactored to support parallel draft with learned embeddings.
The function now properly handles parallel draft by:
- Using real embeddings for the first draft (parallel_draft_index == 0)
- Using learnable parallel_draft_embeddings and parallel_draft_hidden_states for subsequent drafts
- Adjusting attention masks and rotary embeddings accordingly
This aligns with the retrieved learnings about parallel draft behavior.
Based on learnings
925-958: LGTM: Proper inference context threading and cache offset management.
The inference_context is correctly:
- Added as an optional parameter
- Propagated to eagle_module
- Updated with sequence_len_offset after each call for KV cache management
1013-1017: LGTM: KV cache buffer sized correctly.
The StaticInferenceContext buffer size calculation input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1) is correct based on the cache management strategy, where each ttt_step discards (parallel_draft_step - 1) tokens' worth of KV entries. Based on learnings.
1065-1121: LGTM: Training loop refactored for multi-step parallel draft.
The nested loop structure properly implements:
- Outer loop over ttt_steps for temporal processing
- Inner loop over parallel_draft_step for parallel token generation
- Correct KV cache offset management (lines 1119-1121) that discards the cache for the last (parallel_draft_step - 1) tokens
The complexity is inherent to the parallel draft algorithm.
Based on learnings
1136-1157: LGTM: Accuracy reporting properly implemented.
The accuracy reporting:
- Correctly gates on eagle_report_acc and inference mode
- Handles draft vocab size mapping
- Only prints from TP rank 0 to avoid duplicates
- Calculates top-1 accuracy for each parallel draft position
1407-1423: LGTM: Sequence parallel handling is correct but complex.
The gather/scatter pattern for sequence parallelism with the dummy token replacement is correctly implemented. The complexity is inherent to supporting both sequence parallelism and parallel draft functionality.
545-545: No action needed - the code is compatible with project requirements.
The project's setup.py specifies python_requires=">=3.10,<3.13", which fully supports PEP 604 union syntax. Additionally, the codebase already uses this syntax extensively throughout test files and examples, confirming it is an established pattern. The inference_context: StaticInferenceContext | None annotation aligns with both the Python version requirement and existing code conventions.
What does this PR do?
Type of change: New feature
Overview:
This PR implements a KV cache in EAGLE training, which significantly reduces memory consumption and enables longer sequence lengths.
This PR also contains a prototype of a parallel draft implementation. However, that training does not converge, so it is left as future work.
Usage
# Add a code snippet demonstrating how to use this
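Since the template placeholder above was left unfilled, here is a hypothetical usage sketch only. The convert-style entry point and the config layout are assumptions; the field names parallel_draft_step, ttt_steps, and eagle_freeze_base_model appear in this PR, but where they live in the real config is illustrative.

```python
# Hypothetical sketch, not a verified API example.
import modelopt.torch.speculative as mtsp

model = ...  # an existing Megatron-Core GPT model (placeholder)

eagle_config = {
    "parallel_draft_step": 2,         # >1 enables the (prototype) parallel draft path
    "ttt_steps": 2,                   # number of test-time-training steps per sample
    "eagle_freeze_base_model": True,  # zero out the base-model loss during EAGLE training
}

# Convert the base model so EAGLE training uses the new KV-cached forward path.
mtsp.convert(model, [("eagle", eagle_config)])
```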
Testing
Passed regression test.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Documentation
Chores