EAGLE parallel draft with auto regression; kv cache in EAGLE training #391
base: main
Conversation
Walkthrough
Threads a StaticInferenceContext through Eagle and top-level forward paths; replaces fixed multi-block branching with loop-driven per-draft input, attention-mask and kv-cache handling; generalizes sequence-parallel shapes and per-step loss/Top‑1 accumulation; changes the rank-consistency check default to warn (no longer failing by default).
Sequence Diagram(s)
sequenceDiagram
autonumber
actor Caller
participant Model
participant Eagle as EagleModule
participant Out as OutputLayer
participant Ctx as StaticInferenceContext
Caller->>Model: forward(..., inference_context=Ctx)
activate Model
Note over Model: loop per ttt_step / parallel_draft_step\nassemble inputs and masks
loop per draft step
Model->>Eagle: forward(inputs..., inference_context=Ctx)
activate Eagle
Eagle-->>Model: hidden_states, pre_logits
deactivate Eagle
Model->>Out: compute_logits(hidden_states)
Out-->>Model: logits_draft
Model->>Ctx: update sequence_len_offset / kv-cache
Model->>Model: accumulate loss / Top‑1 per draft
end
Model-->>Caller: outputs (logits, metrics)
deactivate Model
sequenceDiagram
autonumber
actor Caller
participant PSG as pseudo_speculative_generate
participant Eagle as EagleModule
participant Ctx as StaticInferenceContext
Caller->>PSG: run(..., inference_context=Ctx)
activate PSG
loop for each draft step
PSG->>Eagle: forward(draft_input, inference_context=Ctx)
Eagle-->>PSG: draft_logits, hidden_states
PSG->>PSG: select draft_token
PSG->>Ctx: update sequence_len_offset / kv-cache note
PSG->>PSG: gather/concat hidden_states across drafts
end
PSG-->>Caller: drafted tokens and states
deactivate PSG
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
762-835
: Critical: define mask_token_* attributes and revise input masking
- No mask_token_{i} attributes found on self, causing an AttributeError. Add mask_token_0, mask_token_1, … in __init__ (or pull mask IDs from config) so getattr(self, f"mask_token_{…}") succeeds.
- The current logic for parallel_draft_step > 1 replaces all positions with the mask token, overwriting base tokens. Instead, clone padded_input_ids and mask only the new draft slot.
- (Optional) Shift position_ids when you left-shift input_ids to keep positional embeddings aligned.
- (Optional) For ttt_step == 1, slice gathered_hidden_states[-s:] to match feature length and avoid mismatched sequence lengths.
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ line 780
-        eagle_inputs["input_ids"] = (
-            padded_input_ids
-            if parallel_draft_step == 1
-            else torch.full(
-                padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
-                device=padded_input_ids.device,
-                dtype=padded_input_ids.dtype,
-            )
-        )
+        if parallel_draft_step == 1:
+            eagle_inputs["input_ids"] = padded_input_ids
+        else:
+            eagle_inputs["input_ids"] = padded_input_ids.clone()
+            mask_tok = getattr(self, f"mask_token_{parallel_draft_step - 2}")
+            eagle_inputs["input_ids"][:, -1] = mask_tok
🧹 Nitpick comments (8)
modelopt/torch/speculative/utils.py (1)
295-312
: Update docstring to reflect warning behavior
The docstring still claims we throw on divergence, but the new default only emits a warning unless fail_when_mismatch=True. Please align the comment with the current behavior so downstream users don’t expect an exception by default.
-    Use rank 0 data as the golden set to broadcast to all ranks.
-    Each rank will then compare to this data and through error if different.
+    Use rank 0 data as the golden set to broadcast to all ranks.
+    Each rank compares its data against this golden set and either raises
+    (when fail_when_mismatch=True) or emits a warning while forcing every
+    rank to adopt rank 0's data.
modelopt/torch/speculative/plugins/megatron_eagle.py (7)
198-253
: Multi-step mask: ensure rectangular q×k compatibility and avoid built-in shadowing.
- iter shadows Python’s built-in; rename to step_idx.
- When called repeatedly (ttt_step > 1), s is captured once from the original k_len; OK for fixed-width appends, but brittle if upstream passes a pre-rectangular mask. Consider deriving q_len, k_len each loop and asserting q_len stays constant.
- Minor perf: replace the per-position loop setting mask_1[:, :, i, i] with vectorized indexing.
Apply:
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
+    s = attn_mask.shape[-1]  # base stride to append each round
+    q = attn_mask.shape[-2]
+    for step_idx in range(2, step + 1):
         # iter starts from 2nd step
-        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
+        mask_0 = attn_mask.clone()
+        mask_0[:, :, step_idx - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], q, s, dtype=torch.bool)
+        diag_idx = torch.arange(step_idx - 1, min(q, s) - 1, device=attn_mask.device)
+        mask_1[:, :, diag_idx, diag_idx] = False
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
512-556
: Threading inference_context into EagleModule is fine; verify downstream acceptance.
Passing inference_context to self.decoder aligns with KV-cache intent. Ensure the underlying TransformerBlock.forward supports it to avoid unexpected kwarg errors at runtime. Also, decoder_input_list is created but unused beyond [0]; remove the list.
Use the previous script to check TransformerBlock.forward signature acceptance of inference_context.
826-835
: RoPE replication can be large; prefer computing once inside EagleModule.
Repeating rotary_pos_emb via cat for each (ttt_step, parallel_draft_step) increases memory. Since EagleModule can compute RoPE when None, consider passing None here (like in pseudo_speculative_generate) and letting the module compute the correct length based on attention_mask.
1014-1018
: StaticInferenceContext capacity heuristic.
The 4× multiplier assumes up to 4 ttt_steps. If this becomes configurable, derive max_tokens = input_ids.size(1) * parallel_draft_step * max_ttt_step from config/kwargs rather than hard-coding 4.
1049-1071
: Loss/acc blocks are duplicated; collapse into a single loop.
Four near-identical sections differ only by ttt_step index, features input, and decay power. Refactor into a for t in range(1, max_ttt_step+1) that:
- builds eagle_inputs_t with the previous pre-norm features
- runs _eagle_forward
- slices/accumulates loss with decay**t
- optionally reports accuracy
This reduces maintenance and indexing mistakes.
Also applies to: 1119-1142, 1170-1192, 1220-1242
1449-1451
: KV cache note in pseudo-speculate: clarify behavior.
The comment states kv cache is unsupported when sequence parallel may be enabled. Consider a runtime check raising a clear error if inference_context is mistakenly passed here, or document the precise conditions (sequence_parallel and/or EP).
142-145
: Docstring vs comparator mismatch.
mcore_version_higher_than doc says “at least this version” but uses strict >. Either change the docstring to “greater than” or switch to >= to avoid confusion.
-def mcore_version_higher_than(target_version: str):
-    """Check if megatron-core is least this version."""
-    return Version(megatron.core.__version__) > Version(target_version)
+def mcore_version_higher_than(target_version: str):
+    """Check if megatron-core is greater than this version."""
+    return Version(megatron.core.__version__) > Version(target_version)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (19 hunks)
modelopt/torch/speculative/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/speculative/utils.py (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  pseudo_speculative_generate (1442-1560)
modelopt/torch/speculative/plugins/transformers.py (1)
  pseudo_speculative_generate (1064-1141)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs (469-516)
_eagle_forward (727-750)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
933-960
: sequence_len_offset update basis.You increment inference_context.sequence_len_offset by eagle_inputs["input_ids"].shape[1] (s) each call. Given set_multi_step_attention_mask appends s keys per call, this matches k-len growth. If q_len changes or future refactors alter append width, this will skew. Safer: derive delta from attention_mask.size(-1) change or keep a local appended_k counter.
1102-1111
: Top-1 mapping and slicing: verify offsets and bounds.
- gathered_logits = gathered_logits[i:-1] (and i+1:-1, etc.) assumes enough length; add asserts or clamp to avoid negative-length slices for short sequences.
- For reduced vocab, eagle_top1 += d2t[eagle_top1] relies on d2t storing offsets; confirm invariant (offset mapping) holds across all training configs.
Also applies to: 1149-1161, 1199-1211, 1249-1261
1499-1504
: Parallel draft token scaffolding correctness.You append parallel_draft_step-1 mask tokens before each EAGLE call and later append the last parallel_draft_step hidden states. Validate that (i) mask tokens align with the same positions EAGLE writes logits for, and (ii) hidden_states growth matches logits selection using [-self.eagle_config.parallel_draft_step:]. Add asserts to catch misalignment.
Also applies to: 1551-1556
29-29
: Guard or pin megatron-core version for StaticInferenceContext
Wrap the import of StaticInferenceContext in a try/except (or enforce a minimum megatron-core dependency) and only pass inference_context when TransformerBlock.forward supports it—manually verify your target megatron-core version provides StaticInferenceContext and accepts the inference_context parameter.
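A minimal sketch of such a guard, assuming the import path shown here (the module path and the HAS_ flag are illustrative, not confirmed against a specific megatron-core release):
try:
    # Import path assumed for illustration; adjust to where your megatron-core exposes the class.
    from megatron.core.inference.contexts import StaticInferenceContext
    HAS_STATIC_INFERENCE_CONTEXT = True
except ImportError:
    StaticInferenceContext = None
    HAS_STATIC_INFERENCE_CONTEXT = False

# Downstream, only build and pass the context when it is actually available, e.g.
# inference_context = StaticInferenceContext(...) if HAS_STATIC_INFERENCE_CONTEXT else None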
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##             main     #391    +/-   ##
=======================================
  Coverage   73.79%   73.79%
=======================================
  Files         171      171
  Lines       17591    17591
=======================================
  Hits        12982    12982
  Misses       4609     4609
☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
198-252
: Refactor set_multi_step_attention_mask to vectorized rectangular mask
Current iterative mask construction both overwrites an entire query row each loop and omits the final diagonal unmask, breaking causal semantics and leaving the last token unable to attend to new keys. Replace the loop with a single allocation of shape [B, H, S, S*step]
, copy the original causal mask into block 0, then unmask only the diagonal entries in each extra block:
def set_multi_step_attention_mask(attn_mask, step):
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
-        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
-
-        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
-
-    return attn_mask
+    b, h, s, _ = attn_mask.shape
+    if step <= 1:
+        return attn_mask
+    new_mask = attn_mask.new_ones(b, h, s, s * step).bool()
+    # block-0: original causal mask
+    new_mask[:, :, :, :s] = attn_mask
+    # extra blocks: only diagonal unmasked
+    idx = torch.arange(s, device=attn_mask.device)
+    for j in range(1, step):
+        new_mask[:, :, idx, j * s + idx] = False
+    return new_mask.contiguous()
Add unit tests for
s=4
andstep=1..4
to verify both the base causal block and each diagonal-only extension.
🧹 Nitpick comments (6)
modelopt/torch/speculative/plugins/megatron_eagle.py (6)
1014-1018
: Avoid magic constant ‘4’ in StaticInferenceContext capacity
StaticInferenceContext(bs, seq_len * parallel_draft_step * 4)
bakes in 4 rounds. Prefer deriving from actual number of EAGLE rounds executed (currently 4) so changes don’t silently under/over‑allocate.Example:
- eagle_inference_context = StaticInferenceContext( - input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4 - ) + num_rounds = 4 # current training rounds; keep in one place + max_len = input_ids.shape[1] * self.eagle_config.parallel_draft_step * num_rounds + eagle_inference_context = StaticInferenceContext(input_ids.shape[0], max_len)Even better: centralize
num_rounds
as a config or class attribute.
1049-1071
: Round‑1 logits/pre‑norm handling: collect features or document intent
eagle_hidden_states_0_pre_norm
is overwritten each iteration; only the last parallel slice feeds round‑2 viafeatures
. If intentional, please comment; if not, collect all pre‑norms and pass the concatenation.Potential fix:
- eagle_logits_0 = [] + eagle_logits_0, pre_norm_list = [], [] @@ - _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward( + _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward( @@ - eagle_logits_0.append(eagle_logits_) + eagle_logits_0.append(eagle_logits_) + pre_norm_list.append(eagle_hidden_states_0_pre_norm)Then use
features=torch.cat(pre_norm_list, dim=0)
(adjust downstream slicing accordingly).
1120-1161
: Rounds 2 logic mirrors Round 1: consider DRY loop + recheck feature feedthroughThese blocks duplicate Round‑1 with different offsets. Consider a single loop over
ttt_step in [1..num_rounds]
that:
- builds inputs,
- forwards EAGLE with shared
eagle_inference_context
,- accumulates logits/loss/acc with computed offsets.
This reduces risk of indexing drift and concentrates the offset math in one place.
1170-1211
: Same as above for Round 3Consolidate into a generic per‑round loop; reduces maintenance surface and makes mask/offset reasoning easier.
1220-1261
: Same as above for Round 4; also check acc slice windowsUnify into loop; confirm slices
i+3:-1
vs labels[:, i+4:]
as intended.
1551-1556
: Hidden state update matches draft expansionConcatenating the last
parallel_draft_step
pre‑norm states aligns with the token appends. Add a brief comment to document this coupling.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(493-540)_eagle_forward
(624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)
29-29
: LGTM: inference context importImporting StaticInferenceContext is appropriate for KV‑cache support here.
512-556
: Threading inference_context into EagleModule forward: goodPassing inference_context to the decoder is required for KV‑cache. Ensure call sites always pass the same context instance for all per‑round EAGLE calls within a forward.
933-961
: KV‑cache offset update: OK; ensure per‑forward context reuse onlyIncrementing
sequence_len_offset
byinput_ids.shape[1]
per call matches the rectangular mask growth. Just make sure a fresh context is created per model forward and reused across all EAGLE calls within that forward (not shared across microbatches).
1093-1098
: Indexing of shifted loss windows: please verify alignment
loss_ = loss_[:, i:]
and accumulation intoloss[:, i+1:]
assume a specific alignment between base logits and each parallel draft i. Add a brief comment and a unit check for off‑by‑one around sequence boundaries.
1102-1111
: Accuracy slice may drop/shift one token
gathered_logits = gathered_logits[i:-1]
discards the last step. Confirm this matches label slicinglabels[:, i+1:]
. A small test on a toy batch would help avoid off‑by‑one.
1449-1450
: Docstring clarity: goodExplicitly stating no KV‑cache here avoids confusion with the training path.
1537-1543
: Top‑k draft slice looks correctSlicing the last
parallel_draft_step
steps to formdraft_token
is consistent with the appended placeholders.
1499-1504
: mask_token_ buffers confirmed—no changes needed*The
mask_token_{i}
buffers are registered viaself.register_buffer
inmodelopt/torch/speculative/eagle/eagle_model.py
, ensuring they exist on the correct dtype and device at runtime.
755-835
: Verify mask_token accessibility and RoPE generation
- Ensure the
mask_token_{i}
buffers registered onEagleModel
are actually available onself
in the plugin (or reference them viaself.eagle_module
) – e.g. addif parallel_draft_step > 1: assert hasattr(self, f"mask_token_{parallel_draft_step - 2}")- Delegate rotary embedding generation to
EagleModule
by passingrotary_pos_emb = None
here and allowing its internal use ofinference_context
to handle offsets.- Confirm that
set_multi_step_attention_mask(attn_mask, step)
produces the expected rectangular mask shape.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
762-795
: Preserve history when masking draft tokens
- In
_get_eagle_module_inputs
, don’t replace the fullinput_ids
whenparallel_draft_step > 1
; clonepadded_input_ids
and overwrite only the rightmostparallel_draft_step – 1
positions with the appropriatemask_token_*
, matching the behavior inpseudo_speculative_generate
.- No
mask_token_{i}
attributes are defined onself
, which will raiseAttributeError
; add those buffers (e.g. in__init__
) or guard against their absence.
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
933-961
: KV-cache offset management: OK but side-effect needs guard on overflow.
You bump sequence_len_offset by input_ids.shape[1] each call. Add a max-cap guard to prevent overflow beyond StaticInferenceContext capacity.
-        if inference_context is not None:
-            inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+        if inference_context is not None:
+            advance = eagle_inputs["input_ids"].shape[1]
+            inference_context.sequence_len_offset = min(
+                inference_context.sequence_len_offset + advance,
+                inference_context.max_seq_len,
+            )
What is the exact field name for capacity in StaticInferenceContext? Adjust accordingly.
1014-1018
: Hard-coded “×4” capacity couples KV-cache to 4 rounds; derive from schedule.
If ttt rounds change, this under/over-allocates. Compute from the actual training schedule:
-    eagle_inference_context = StaticInferenceContext(
-        input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
-    )
+    max_rounds = 4  # TODO: derive from a config/arg if you plan to vary rounds
+    max_len = input_ids.shape[1] * self.eagle_config.parallel_draft_step * max_rounds
+    eagle_inference_context = StaticInferenceContext(input_ids.shape[0], max_len)
If you plan to auto-regress >4, promote max_rounds to a config knob.
1499-1504
: Mask token attributes must exist and be on the right device/dtype.
Accessing getattr(self, f"mask_token_{i}") without validation can raise. Add a small helper or fallback, and ensure to .to(eagle_ids.device, dtype=eagle_ids.dtype) before assignment.
Also applies to: 1537-1543
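A small helper along these lines could centralize that lookup; the function name and error message below are illustrative, not existing code in this module:
def _get_mask_token(module, i, ref_ids):
    # Hypothetical helper: fetch the registered mask_token_{i} buffer and move it
    # onto the same device/dtype as the ids it will be written into.
    tok = getattr(module, f"mask_token_{i}", None)
    if tok is None:
        raise AttributeError(f"mask_token_{i} buffer is not registered on {type(module).__name__}")
    return tok.to(device=ref_ids.device, dtype=ref_ids.dtype)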
1551-1555
: Hidden-state growth: ensure KV/activation memory is bounded.
You append next_hidden_states_input slices each step; for large steps this grows O(steps). If steps is bounded small, OK. Otherwise consider reusing a rolling buffer.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(493-540)_eagle_forward
(624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: example-tests-pr (llm_ptq)
- GitHub Check: gpu-tests-pr
- GitHub Check: multi-torch (27)
- GitHub Check: multi-torch (26)
- GitHub Check: multi-transformers (min)
- GitHub Check: windows
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
29-29
: Good addition: inference context import.Importing StaticInferenceContext here makes sense for threading KV-cache state.
512-512
: Threading inference_context is correct.Forwarding the context into decoder is consistent with KV-cache usage.
Confirm that TransformerBlock in your mcore version accepts inference_context; older versions may ignore or raise.
Also applies to: 553-554
826-835
: Mask/pos-emb replication factor looks right; verify rotary length matches key length.You replicate rotary_pos_emb by the effective step count. Ensure decoder expects k_len rotary embeddings equal to attention_mask K dimension.
1093-1098
: Loss windowing looks consistent; verify off-by-one with new concatenated features.Given the concatenation change above, double-check the i-based slicing still aligns with label offsets.
Also applies to: 1143-1148, 1193-1198, 1243-1248
for iter in range(2, step + 1):
    # iter starts from 2nd step
    mask_0 = attn_mask.clone().detach()
    mask_0[:, :, iter - 2, :] = True
    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
    for i in range(iter - 1, s - 1):
        mask_1[:, :, i, i] = False

    attn_mask = torch.cat((mask_0, mask_1), dim=-1)
🧩 Analysis chain
Multi-step mask: prevent OOB and vectorize; current loop risks indexing past q_len and is slow.
- iter-2 row index will go out of bounds when step > s.
- The inner for-loop over i is O(s) per step; can be vectorized.
Apply bounds + vectorization:
- for iter in range(2, step + 1):
+ s = attn_mask.shape[-1]
+ # Cap steps to avoid indexing beyond query length.
+ for iter in range(2, min(step, s) + 1):
# iter starts from 2nd step
mask_0 = attn_mask.clone().detach()
mask_0[:, :, iter - 2, :] = True
mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
- mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
- for i in range(iter - 1, s - 1):
- mask_1[:, :, i, i] = False
+ mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s, dtype=torch.bool)
+ idx = torch.arange(iter - 1, s - 1, device=attn_mask.device)
+ mask_1[:, :, idx, idx] = False
Please add a small unit test to exercise step values {1, 2, s, s+3}.
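A hedged sketch of such a test, assuming the [batch, heads, s, s] boolean layout with True meaning "masked" that the surrounding code implies; steps beyond s are left out until the bounds fix above lands:
import torch

from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask


def test_set_multi_step_attention_mask_shapes():
    b, h, s = 2, 1, 4
    causal = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
    attn_mask = causal[None, None].repeat(b, h, 1, 1)
    for step in (1, 2, s):
        out = set_multi_step_attention_mask(attn_mask.clone(), step)
        # Query length is preserved and the key length grows in whole multiples of s.
        assert out.shape[:3] == (b, h, s)
        assert out.shape[-1] % s == 0
        assert out.dtype == torch.bool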
🏁 Script executed:
#!/bin/bash
sed -n '230,270p' modelopt/torch/speculative/plugins/megatron_eagle.py
Length of output: 2355
🏁 Script executed:
#!/bin/bash
sed -n '150,300p' modelopt/torch/speculative/plugins/megatron_eagle.py
Length of output: 9208
Cap multi-step range to sequence length and vectorize mask construction
- Prevent OOB at mask_0[:, :, iter-2, :] by capping the loop:
  # s = attn_mask.shape[-1]  # already defined above
  for iter in range(2, min(step, s) + 1):
      …
- Replace the per-element diagonal loop with a batched index assignment:
  mask_1 = attn_mask.new_ones(batch, heads, s, s, dtype=torch.bool)
  idx = torch.arange(iter - 1, s - 1, device=attn_mask.device)
  mask_1[:, :, idx, idx] = False
- Add unit tests covering step = 1, 2, s, s+3 to verify no OOB and correct masking.
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 242 to 252,
the for-loop can go out-of-bounds when step > sequence length and the mask_1
diagonal is built with a Python loop; change the loop range to cap at the
sequence length (use range(2, min(step, s) + 1)) and replace the per-element
diagonal loop with a vectorized batched index assignment using torch.arange on
the correct device/dtype to set mask_1[:, :, idx, idx] = False, ensure mask_1 is
created with the same device and boolean dtype as attn_mask, and add unit tests
for step values 1, 2, s, and s+3 to assert no OOB and correct mask
shapes/values.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
243-249
: Critical: Out-of-bounds access when step > sequence_length.When
step > s
(wheres = attn_mask.shape[-1]
), line 245 will attempt to indexmask_0[:, :, step_idx, :]
withstep_idx >= s
, causing an out-of-bounds error. Additionally, the inner loop at lines 248-249 is O(s) per iteration and should be vectorized.Apply this fix to cap the loop range and vectorize the diagonal assignment:
s = attn_mask.shape[-1] -for step_idx in range(step): +# Cap step_idx to avoid indexing beyond query length +for step_idx in range(min(step, s)): mask_0 = attn_mask.clone().detach() mask_0[:, :, step_idx, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] - mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() - for i in range(step_idx + 1, s - 1): - mask_1[:, :, i, i] = False + mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s, dtype=torch.bool) + # Vectorize diagonal assignment + if step_idx + 1 < s - 1: + idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device) + mask_1[:, :, idx, idx] = False attn_mask = torch.cat((mask_0, mask_1), dim=-1)Add unit tests covering
step
values of 1, 2, s, and s+3 to verify correctness.Based on the past review comment, this issue was previously identified but not yet resolved.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
826-834
: Renameparallel_draft_step
parameter for clarityThe step calculation (
ttt_step * eagle_config.parallel_draft_step + draft_index
) is correct; rename theparallel_draft_step
parameter (and its occurrences and docstring) todraft_index
to distinguish it fromeagle_config.parallel_draft_step
.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(493-540)_eagle_forward
(624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)
29-29
: LGTM: Import addition supports KV cache feature.The
StaticInferenceContext
import is necessary for the new KV cache functionality described in the PR objectives.
143-143
: LGTM: Improved docstring accuracy.The docstring now correctly describes the strict inequality check (
>
).
512-512
: LGTM: Signature addition supports KV cache propagation.The
inference_context
parameter enables KV cache management throughout the EAGLE module forward pass, aligning with the PR objectives.
796-805
: Correct: Feature slicing preserves last s tokens only.The slicing
feature = gathered_features[-s:]
is correct, as previous parallel draft tokens are preserved in the KV cache (as confirmed in past review discussion). The KV cache mechanism maintains the full context across iterations.Based on learnings from prior review.
947-949
: LGTM: KV cache offset correctly updated after each EAGLE call.The
sequence_len_offset
increment ensures the KV cache tracks the current position across multiple EAGLE module invocations, which is essential for parallel draft functionality.
1107-1129
: LGTM: Per-draft accuracy calculation correctly indexes logits and labels.The accuracy computation maintains consistent next-token prediction offsets across drafts:
- Logits sliced from
[i + ttt_step : -1]
- Labels sliced from
[:, i + ttt_step + 1 :]
The 1-position shift correctly aligns predictions with target tokens.
1361-1365
: LGTM: Parallel draft token generation correctly appends mask tokens.The loop appends
parallel_draft_step - 1
mask tokens and corresponding hidden states, which together with the one real token forms a complete set ofparallel_draft_step
positions for EAGLE to predict.
1399-1418
: LGTM: Draft token extraction correctly handles parallel drafts.The slicing
[-self.eagle_config.parallel_draft_step :]
extracts all parallel draft tokens and their corresponding hidden states, maintaining consistency with the parallel draft architecture.
1099-1105
: I’ve extracted the EAGLE loss helper to inspect its behavior. Let me know the output so I can verify the accumulation logic.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
242-253
: Fix OOB + in-place aliasing in multi-step mask; vectorize diagonal writes.
- Range loop can index past q_len when step > s; cap to s.
- In-place shift uses overlapping slices; RHS must be cloned to avoid undefined results.
- Build mask_1 with vectorized diagonal assignment; avoid Python loop.
Apply:
- s = attn_mask.shape[-1] - for step_idx in range(step): - mask_0 = attn_mask.clone().detach() - mask_0[:, :, step_idx, :] = True - mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] - mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() - for i in range(step_idx + 1, s - 1): - mask_1[:, :, i, i] = False - attn_mask = torch.cat((mask_0, mask_1), dim=-1) + s = attn_mask.shape[-2] # query length + iter_max = min(step, s) + for step_idx in range(iter_max): + mask_0 = attn_mask.clone() + mask_0[:, :, step_idx, :] = True + # Avoid aliasing on overlapping slices + mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone() + # Vectorized diagonal + mask_1 = torch.ones( + (*attn_mask.shape[:2], s, s), dtype=torch.bool, device=attn_mask.device + ) + if step_idx + 1 < s - 1: + idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device) + mask_1[:, :, idx, idx] = False + attn_mask = torch.cat((mask_0, mask_1), dim=-1)Also consider preallocating mask_1 outside the loop and only rewriting the diagonal slice per iteration.
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)
830-835
: Add a shape assertion to ensure rotary_pos_emb matches expanded k_len.Guard against subtle mismatches after repetition.
e.g., +assert ( + eagle_inputs["rotary_pos_emb"].shape[0] == eagle_inputs["attention_mask"].shape[-1] +), "rotary_pos_emb length must equal attention_mask k_len"
1048-1066
: Prefer padding labels with ignore_index instead of 0.Zero may be a valid token id; using it biases loss. Pad with the model’s ignore index (commonly -100).
- right_token_pad = torch.zeros( - (labels.shape[0], 1), - dtype=labels.dtype, - device=labels.device, - ) + ignore_index = getattr(self, "label_ignore", -100) + right_token_pad = torch.full( + (labels.shape[0], 1), fill_value=ignore_index, dtype=labels.dtype, device=labels.device + ) labels = torch.cat((labels, right_token_pad), dim=-1)Please confirm compute_language_model_loss uses ignore_index consistently.
947-950
: Sequence length accounting in KV cache: add a safety check.Incrementing by input_ids.shape[1] per call assumes exactly s new KV entries each call. Add a guard to avoid exceeding the preallocated max.
if inference_context is not None: - inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + assert ( + inference_context.sequence_len_offset + <= inference_context.max_sequence_length + ), "StaticInferenceContext capacity exceeded"Also applies to: 1015-1019
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(493-540)_eagle_forward
(624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
796-806
: Feature slicing logic is correct with KV cache; keep last s only.Using only the last s features for the next ttt_step is expected because previous drafts are preserved in the KV cache; the slice
gathered_features[-s:]
is correct.Based on learnings
Also applies to: 806-819, 821-825
1067-1070
: Pre‑norm overwrite per draft is intentional; no change needed.Overwriting eagle_hidden_states_*_pre_norm inside the per‑draft loop is expected for parallel draft.
Based on learnings
Also applies to: 1092-1092
1311-1312
: KV cache deliberately disabled in pseudo_speculative_generate; good.Not passing inference_context here avoids cache/sequence‑parallel interplay; the guard comment is clear.
Also applies to: 1387-1390
29-29
: Guard StaticInferenceContext import and conditional inference_context usage for backward compatibility. Verify the minimum megatron-core version that introducesStaticInferenceContext
and adds theinference_context
parameter toTransformerBlock.forward
; if you must support earlier versions, apply the import guard and conditional kwarg wiring as shown above in all relevant blocks (lines 512–515, 553, 933–944, 1015–1019).
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
242-252
: Prevent potential OOB and consider vectorization.The loop iterates
for step_idx in range(step)
and indexesmask_0[:, :, step_idx, :]
at line 245. Ifstep
exceedss
(the sequence length), this will cause an out-of-bounds error. Additionally, the inner loop constructing the diagonal mask (lines 248-249) could be vectorized for better performance.Apply this defensive cap and vectorization:
s = attn_mask.shape[-1] - for step_idx in range(step): + # Cap step to prevent OOB indexing + for step_idx in range(min(step, s)): mask_0 = attn_mask.clone().detach() mask_0[:, :, step_idx, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() - for i in range(step_idx + 1, s - 1): - mask_1[:, :, i, i] = False + # Vectorize diagonal assignment + if step_idx + 1 < s - 1: + idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device) + mask_1[:, :, idx, idx] = False attn_mask = torch.cat((mask_0, mask_1), dim=-1)Note: This addresses the concern raised in previous reviews about OOB and performance.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
142-144
: Clarify version comparison semantics.The function name
mcore_version_higher_than
and implementation using>
indicate a strict comparison, but "greater than" can be ambiguous. Consider updating the docstring to explicitly state "strictly greater than" for clarity.- """Check if megatron-core is greater than this version.""" + """Check if megatron-core version is strictly greater than the target version."""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(493-540)_eagle_forward
(624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: wait-checks / wait
🔇 Additional comments (8)
modelopt/torch/speculative/plugins/megatron_eagle.py (8)
947-950
: LGTM: Proper KV cache offset tracking.The
sequence_len_offset
update after each EAGLE module call correctly maintains the KV cache index for incremental generation. This aligns with the PR objective of adding KV cache support in EAGLE training.
796-825
: LGTM: Proper sequence parallel handling.The gather/scatter pattern correctly handles sequence parallelism, and the feature shifting logic for
ttt_step > 0
properly implements the autoregressive pattern for parallel draft generation.
1107-1129
: LGTM: Per-draft accuracy reporting properly implemented.The accuracy computation correctly handles per-draft evaluation with proper vocab mapping and index alignment. The reporting provides useful per-draft metrics for monitoring training progress.
1360-1424
: LGTM: Parallel draft generation properly implemented.The restructured loop correctly handles parallel draft token generation:
- Prepares input by appending mask tokens for parallel drafts (lines 1361-1365)
- Selects multiple draft tokens from the output (lines 1399-1405)
- Properly cleans up temporary mask tokens before next iteration (lines 1411-1414)
- Maintains correct hidden state sequence (lines 1418-1423)
This aligns with the PR objective of supporting parallel draft with autoregression.
1048-1065
: LGTM: Label padding handles offline training edge case.The label padding logic correctly handles the case where offline training labels are one token shorter. The comment appropriately documents the small training error this introduces when
logit_distillation=False
. The conditional loss zeroing for frozen base model is also correct.
1015-1019
: LGTM: KV cache properly sized for parallel drafts.The
StaticInferenceContext
initialization correctly allocates cache capacity for all parallel draft steps and TTT steps. This enables the rectangular attention mask optimization mentioned in the PR objectives, reducing memory consumption and avoiding redundant computation.
785-794
: No action needed on mask tokens. Initialization ineagle_model.py
covers allmask_token_{i}
used by the plugin.
1099-1105
: Loss indexing and decay weighting align correctly with autoregressive offsets; no changes needed.
fd1fb15 to c02a99d — Compare
…ach ttt step, only the non_parallel tokens from previous ttt are used as context Signed-off-by: Ye Yu <[email protected]>
c02a99d to b47928e — Compare
Signed-off-by: Ye Yu <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
Actionable comments posted: 3
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
828-841
: Potential None dereference in gathered_featuresLine 838 accesses
gathered_features[:-1, :, :]
whenttt_step > 0
, but this assumesgathered_features
is not None. While the caller may ensure this, adding a defensive check would improve robustness.Consider adding an assertion or guard:
eagle_inputs["hidden_states"] = ( gathered_hidden_states if ttt_step == 0 else torch.cat( ( torch.zeros( (1, b, h), dtype=hidden_states.dtype, device=hidden_states.device, ), + # gathered_features should not be None when ttt_step > 0 gathered_features[:-1, :, :], # type: ignore[index] ) ) )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(16 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
_get_eagle_module_inputs
(501-548)_eagle_forward
(632-657)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/speculative/plugins/megatron_eagle.py (7)
29-29
: LGTM: StaticInferenceContext importThe import is correctly added to support kv cache management in EAGLE training and inference.
143-143
: LGTM: Documentation clarificationThe docstring update clarifies the version comparison semantics.
948-982
: LGTM: Proper inference_context handlingThe changes correctly:
- Thread
inference_context
through the EAGLE forward pass- Update
sequence_len_offset
after each EAGLE module call (lines 968-970) to maintain kv cache consistency- Return all necessary hidden states for the next iteration
1090-1125
: LGTM: Nested loop structure for parallel draftThe nested loop structure correctly implements the parallel draft with auto-regression:
- Outer loop (line 1090) iterates over
ttt_steps
- Inner loop (line 1092) iterates over
parallel_draft_step
to generate multiple tokens per step- Line 1112-1113 correctly preserves only the first draft's pre-norm states (per learnings)
- Lines 1119-1124 properly adjust
sequence_len_offset
to discard intermediate drafts' kv cache entriesBased on learnings
1131-1153
: LGTM: Loss and accuracy computationThe loss computation and accuracy reporting correctly:
- Apply loss decay factor:
eagle_loss_decay_factor ** (ttt_step + i)
(line 1136)- Handle loss offsets properly for each draft position (lines 1134-1137)
- Map draft tokens to target vocabulary when needed (line 1148)
- Compute top-1 accuracy for each parallel draft (lines 1139-1153)
1342-1343
: LGTM: Clarified kv cache constraintThe updated docstring correctly documents that kv cache is not supported when sequence parallel is enabled.
1392-1454
: LGTM: Parallel draft token generationThe implementation correctly handles parallel draft generation:
- Lines 1392-1396: Add
parallel_draft_step - 1
mask tokens and duplicate hidden states- Lines 1430-1436: Extract
parallel_draft_step
draft tokens from logits- Lines 1442-1445: Clean up mask tokens before adding real draft tokens
- Lines 1449-1453: Concatenate appropriate hidden states for the next iteration
The logic is consistent and aligns with the parallel draft mechanism.
for step_idx in range(step):
    mask_0 = attn_mask.clone().detach()
    mask_0[:, :, step_idx, :] = True
    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
    for i in range(step_idx + 1, s - 1):
        mask_1[:, :, i, i] = False

    attn_mask = torch.cat((mask_0, mask_1), dim=-1)
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify bounds and consider vectorization
The loop construction may encounter issues:
- Line 272: When
step_idx >= s - 1
, the rangerange(step_idx + 1, s - 1)
becomes empty, which is safe but may indicate unexpected behavior for largestep
values. - The inner loop at line 272-273 is O(s) per step and could be vectorized for efficiency.
Consider:
- Adding validation that
step
is within expected bounds relative to sequence length - Vectorizing the diagonal assignment using
torch.arange
Based on past review comments, this was previously flagged. Please verify the current implementation handles all edge cases correctly, especially when step > s
.
🏁 Script executed:
#!/bin/bash
# Search for calls to set_multi_step_attention_mask to understand expected step values
rg -n -A 3 -B 3 'set_multi_step_attention_mask' --type=py
Length of output: 1221
🏁 Script executed:
#!/bin/bash
sed -n '190,260p' modelopt/torch/speculative/plugins/megatron_eagle.py
sed -n '260,340p' modelopt/torch/speculative/plugins/megatron_eagle.py
Length of output: 10130
Validate step bounds and vectorize diagonal mask
- In set_multi_step_attention_mask (modelopt/torch/speculative/plugins/megatron_eagle.py), add at the top:
  s = attn_mask.shape[-1]
  if step > s:
      raise ValueError(f"step ({step}) must be ≤ sequence length ({s})")
- Replace the inner for i in range(step_idx+1, s-1): mask_1[..., i, i] = False loop with a vectorized assignment (see the equivalence check after this list), e.g.:
  idx = torch.arange(step_idx+1, s-1, device=attn_mask.device)
  mask_1[..., idx, idx] = False
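For reference, a self-contained check that the batched write matches the original per-index loop; the shapes and the boolean convention are assumed from the surrounding code:
import torch

b, h, s, step_idx = 2, 1, 6, 1
loop_mask = torch.ones(b, h, s, s, dtype=torch.bool)
for i in range(step_idx + 1, s - 1):
    loop_mask[:, :, i, i] = False

vec_mask = torch.ones(b, h, s, s, dtype=torch.bool)
idx = torch.arange(step_idx + 1, s - 1)
vec_mask[:, :, idx, idx] = False

assert torch.equal(loop_mask, vec_mask)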
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 267–276,
validate that step does not exceed the sequence length by reading s from
attn_mask.shape[-1] and raising a ValueError if step > s; then replace the
per-index Python loop that clears the diagonal in mask_1 with a vectorized
assignment using a torch.arange of indices on the same device as attn_mask to
set mask_1[..., idx, idx] = False (ensure mask_1 remains boolean and use
attn_mask.device for the arange).
# EAGLE kv cache
eagle_inference_context = StaticInferenceContext(
    input_ids.shape[0],
    input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1),
)
🧩 Analysis chain
Verify kv cache buffer size calculation
The buffer size is calculated as input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)
. Let me verify this accounts for all tokens correctly:
- Initial tokens:
input_ids.shape[1]
- Each ttt_step generates:
parallel_draft_step
tokens - But only 1 token per ttt_step is kept for the next iteration
- Total additional tokens: approximately
ttt_steps * parallel_draft_step
The formula (parallel_draft_step + ttt_steps - 1)
seems like it might underestimate when both parameters are large.
Please verify that the buffer size correctly accounts for the maximum number of tokens that will be present in the kv cache at any point during the forward pass.
🏁 Script executed:
#!/bin/bash
# Search for StaticInferenceContext usage and buffer size patterns
rg -n -B 3 -A 3 'StaticInferenceContext' --type=py
Length of output: 4446
Fix kv-cache buffer size calculation
The size passed to StaticInferenceContext at modelopt/torch/speculative/plugins/megatron_eagle.py:1036–1040 is wrong—multiplying the original seq-len by (parallel_draft_step + ttt_steps - 1)
miscomputes the total tokens processed. It should be based on addition, e.g.:
max_seq_len = input_ids.shape[1] + self.eagle_config.parallel_draft_step * ttt_steps
so the KV cache can hold the original tokens plus all generated ones.
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 1036 to
1040, the KV-cache buffer size is computed incorrectly by multiplying the
original sequence length by (parallel_draft_step + ttt_steps - 1); change the
calculation to compute the maximum sequence length as the original input length
plus all generated tokens: max_seq_len = input_ids.shape[1] +
self.eagle_config.parallel_draft_step * ttt_steps, and pass that value as the
second argument to StaticInferenceContext so the KV cache can hold the original
tokens plus all generated tokens.
No, the original value is correct
@yeyu-nvidia, thank you for the clarification! I see now that the calculation is correct.
Looking at the code flow more carefully, I can see that:
- Each
ttt_step
processesinput_ids.shape[1] * parallel_draft_step
tokens - After each iteration,
sequence_len_offset
is decremented byinput_ids.shape[1] * (parallel_draft_step - 1)
(line 1067-1070) - This means only
input_ids.shape[1]
tokens worth of KV cache is retained per iteration - The total maximum KV cache size needed is therefore:
input_ids.shape[1] * (ttt_steps + parallel_draft_step - 1)
The original formula correctly accounts for the fact that KV cache entries are being reused/discarded during the loop, not simply accumulated. My apologies for the confusion!
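A quick standalone sanity check of that accounting, with made-up sizes and mirroring the advance/roll-back behavior described above (illustrative only):
s, parallel_draft_step, ttt_steps = 128, 3, 4
offset, peak = 0, 0
for _ in range(ttt_steps):
    for _ in range(parallel_draft_step):
        offset += s  # each EAGLE call advances the cache by one sequence width
        peak = max(peak, offset)
    offset -= s * (parallel_draft_step - 1)  # discard all but one token-width per ttt_step
assert peak == s * (ttt_steps + parallel_draft_step - 1)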
✏️ Learnings added
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.819Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
🧠 Learnings used
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Signed-off-by: Ye Yu <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
267-276: Validate `step` parameter and vectorize diagonal mask construction.

The loop construction may encounter issues:

- Out-of-bounds risk: when `step > s` (sequence length), line 269's `mask_0[:, :, step_idx, :]` will index beyond the tensor dimension.
- Performance: the inner loop at lines 272-273 is O(s) per step and could be vectorized.

Apply this fix to add bounds validation and vectorize:

```diff
 def set_multi_step_attention_mask(attn_mask, step):
+    s = attn_mask.shape[-1]
+    if step > s:
+        raise ValueError(f"step ({step}) must be ≤ sequence length ({s})")
+
-    s = attn_mask.shape[-1]
     for step_idx in range(step):
         mask_0 = attn_mask.clone().detach()
         mask_0[:, :, step_idx, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx + 1, s - 1):
-            mask_1[:, :, i, i] = False
+        if step_idx + 1 < s - 1:
+            idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
```

This resolves the concerns from the previous review and adds safety + performance improvements.
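As a quick shape sanity check, assuming `set_multi_step_attention_mask` is importable from `megatron_eagle.py` and behaves as in the diff above (the sizes here are made up):

```python
import torch

s = 4
base_mask = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)

out = set_multi_step_attention_mask(base_mask, 2)
# each step concatenates one s-wide block, so the last dim grows to s * (step + 1)
assert out.shape == (1, 1, s, s * 3)
```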
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
- `_get_eagle_module_inputs` (501-548)
- `_eagle_forward` (632-657)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (8)
modelopt/torch/speculative/plugins/megatron_eagle.py (8)
29-29: LGTM! The `StaticInferenceContext` import is correctly added to support KV cache functionality in EAGLE training and inference.
143-143: LGTM! The docstring clarification from "higher than" to "greater than" improves precision.
817-835: LGTM! The Eagle module input construction correctly handles:

- Conditional input_ids/embeddings based on `parallel_draft_index`
- Hidden states selection for first vs. subsequent drafts
- Multi-step attention mask generation
- Rotary position embeddings concatenation for multiple steps

The logic aligns with the parallel draft architecture described in the PR objectives.
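As a small illustration of the concatenation pattern used for multiple steps (all shapes below are invented; only the `torch.cat` pattern mirrors the snippet shown later in this thread):

```python
import torch

# Illustrative shapes only: repeat position ids and rotary embeddings once per
# concatenated step when assembling the Eagle inputs.
steps = 3
position_ids = torch.arange(8).unsqueeze(0)    # [b, s]
rotary_pos_emb = torch.randn(8, 1, 1, 4)       # [s, 1, 1, dim]

multi_position_ids = torch.cat([position_ids] * steps, dim=-1)  # [b, s * steps]
multi_rotary = torch.cat([rotary_pos_emb] * steps, dim=0)       # [s * steps, 1, 1, dim]
```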
921-954: LGTM! The `_eagle_forward` method correctly:

- Threads `inference_context` through to the eagle_module
- Updates `sequence_len_offset` after processing to maintain KV cache consistency
- Returns the same values with improved formatting

The KV cache handling is sound and aligns with the learned context about buffer size calculations.

Based on learnings
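A minimal sketch of that flow; apart from `sequence_len_offset`, the names and argument order are stand-ins rather than the exact ModelOpt signatures:

```python
# Sketch only: thread a static KV-cache context through the draft forward pass
# and advance its write pointer afterwards.
def eagle_forward_sketch(eagle_module, eagle_inputs, inference_context):
    hidden_states = eagle_module(
        eagle_inputs["input_ids"],
        eagle_inputs["position_ids"],
        eagle_inputs["attention_mask"],
        hidden_states=eagle_inputs["hidden_states"],
        inference_context=inference_context,
    )
    # the KV cache now holds entries for every token just processed
    inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[-1]
    return hidden_states
```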
1009-1013: LGTM! The KV cache buffer size calculation `input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1)` is correct. Each `ttt_step` iteration discards `(parallel_draft_step - 1)` worth of KV entries (lines 1115-1117), so only one token's worth of KV cache is retained per `ttt_step`.

Based on learnings
1042-1155: LGTM! The forward method training loop correctly implements:

- Label padding for offline training compatibility (lines 1042-1053)
- Nested loops for `ttt_steps` and `parallel_draft_step` (lines 1062-1087)
- Per-draft Eagle forward passes with KV cache context (lines 1074-1081)
- Preservation of first draft's pre-norm hidden states for next iteration (lines 1083-1084), which is the expected behavior per learnings
- Sequence-parallel gathering/scattering (lines 1089-1110)
- KV cache offset adjustment to discard unused entries (lines 1115-1117)
- Loss accumulation with exponential decay (lines 1124-1130)
- Top-1 accuracy reporting per draft (lines 1132-1146)

The implementation aligns with the PR objectives of parallel draft with auto-regression and efficient KV cache management.

Based on learnings
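For the loss-decay item, a self-contained toy example of the weighting scheme; the decay factor and per-step losses are invented, not the PR's actual values:

```python
import torch

# Toy example: later ttt_steps contribute with exponentially decayed weights.
step_losses = [torch.tensor(2.0), torch.tensor(1.5), torch.tensor(1.2)]
decay = 0.8

loss = torch.tensor(0.0)
weight = 1.0
for step_loss in step_losses:
    loss = loss + weight * step_loss
    weight *= decay

print(loss.item())  # 2.0 + 0.8 * 1.5 + 0.64 * 1.2 = 3.968
```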
1385-1391: Verify `mask_token_*` buffer registration. Lines 1387-1388 reference `getattr(self, f"mask_token_{i}")`, which requires the buffers flagged as missing in the earlier review comment (lines 806-815). Ensure those buffers are registered in `EagleModule.__init__` before this code can execute correctly.

Otherwise, the logic for adding mask tokens and padding hidden states for parallel draft generation is sound.
1405-1465: LGTM! The `pseudo_speculative_generate` method correctly implements parallel draft generation:

- Proper sequence-parallel handling for embeddings and hidden states (lines 1405-1417)
- Correct replacement of dummy hidden states with embeddings for mask tokens (lines 1409-1415)
- Appropriate extraction of draft tokens from multiple positions (lines 1436-1442)
- Clean removal of mask tokens before concatenating new drafts (lines 1448-1461)

The implementation maintains consistency with the parallel draft architecture and handles tensor operations correctly.
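To illustrate the "extraction of draft tokens from multiple positions" item, a small standalone example in Megatron's [seq, batch, vocab] layout; all shapes and values here are invented for illustration:

```python
import torch

# Illustrative only: read one draft token per parallel-draft position from the
# last `parallel_draft_step` logit positions and append them to the drafts.
parallel_draft_step, b, vocab = 2, 1, 16
eagle_logits = torch.randn(8, b, vocab)                  # [s, b, vocab]
draft_tokens = torch.empty((b, 0), dtype=torch.long)     # accumulated drafts

new_tokens = eagle_logits[-parallel_draft_step:].argmax(dim=-1).transpose(0, 1)  # [b, steps]
draft_tokens = torch.cat((draft_tokens, new_tokens), dim=-1)
print(draft_tokens.shape)  # torch.Size([1, 2])
```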
eagle_inputs["input_ids"] = ( | ||
padded_input_ids | ||
if parallel_draft_index == 0 | ||
else torch.full( | ||
padded_input_ids.shape, | ||
getattr(self, f"mask_token_{parallel_draft_index - 1}"), | ||
device=padded_input_ids.device, | ||
dtype=padded_input_ids.dtype, | ||
) | ||
if self.config.sequence_parallel: | ||
eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( | ||
eagle_inputs["hidden_states"] | ||
) | ||
|
||
eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3) | ||
eagle_inputs["position_ids"] = torch.cat( | ||
(position_ids, position_ids, position_ids), dim=-1 | ||
) | ||
|
||
if rotary_pos_emb is not None: | ||
eagle_inputs["rotary_pos_emb"] = torch.cat( | ||
(rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), | ||
dim=0, | ||
) | ||
else: | ||
# [TODO] (yeyu): there will be problem here with MLA | ||
eagle_inputs["rotary_pos_emb"] = None | ||
else: | ||
eagle_inputs["input_ids"] = torch.cat( | ||
(padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids), | ||
dim=-1, | ||
) | ||
|
||
if self.config.sequence_parallel: | ||
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) | ||
gathered_features = gather_from_sequence_parallel_region(features) | ||
else: | ||
gathered_hidden_states = hidden_states | ||
gathered_features = features | ||
eagle_inputs["hidden_states"] = torch.cat( | ||
( | ||
gathered_hidden_states, | ||
torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), | ||
gathered_features[:-1, :, :], | ||
), | ||
dim=0, | ||
) | ||
if self.config.sequence_parallel: | ||
eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( | ||
eagle_inputs["hidden_states"] | ||
) | ||
|
||
eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4) | ||
eagle_inputs["position_ids"] = torch.cat( | ||
(position_ids, position_ids, position_ids, position_ids), dim=-1 | ||
) | ||
|
||
if rotary_pos_emb is not None: | ||
eagle_inputs["rotary_pos_emb"] = torch.cat( | ||
(rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), | ||
dim=0, | ||
) | ||
else: | ||
# [TODO] (yeyu): there will be problem here with MLA | ||
eagle_inputs["rotary_pos_emb"] = None | ||
) |
Register `mask_token_*` buffers in `EagleModule.__init__`.

Lines 811-814 reference `self.mask_token_{parallel_draft_index - 1}` attributes, but these buffers are never registered in the `EagleModule.__init__` method (lines 369-482). This will cause an `AttributeError` at runtime when `parallel_draft_index > 0`.

Add buffer registration in `EagleModule.__init__` (after line 482, before the closing of `__init__`):

```python
# Register mask token buffers for parallel draft
if self.config.parallel_draft_step > 1:
    for i in range(self.config.parallel_draft_step - 1):
        # Use a sentinel value like vocab_size or a special token ID
        self.register_buffer(
            f"mask_token_{i}",
            torch.tensor(self.config.vocab_size - 1 - i, dtype=torch.long)
        )
```

Note: Choose an appropriate mask token ID based on your tokenizer (e.g., a special unused token, or `vocab_size - 1`).
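Registering these values via `register_buffer` (rather than assigning plain attributes) also keeps the mask token IDs in the module's state dict and moves them with `.to(device)`, so the existing `getattr(self, f"mask_token_{i}")` lookups in the draft path resolve to tensors on the correct device.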
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 806-815,
self.mask_token_{i} attributes are referenced but never registered causing
AttributeError when parallel_draft_index > 0; fix by registering those
mask_token_* buffers in EagleModule.__init__ (after line 482) by looping for i
in range(self.config.parallel_draft_step - 1) and calling
self.register_buffer(f"mask_token_{i}", torch.tensor(<chosen_token_id>,
dtype=torch.long)) where <chosen_token_id> is an appropriate sentinel (e.g.
self.config.vocab_size - 1 - i or a dedicated unused token id /
self.config.mask_token_id if available) so the attributes exist as long tensors
on the module.
What does this PR do?
Type of change: new feature
Overview:
This PR adds 2 features:
Usage
No API change.
Testing
Offline tested with Qwen3-30B and Llama3.2-1B. Passed nmm-sandbox regression test.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores