
Conversation

yeyu-nvidia
Contributor

@yeyu-nvidia yeyu-nvidia commented Sep 29, 2025

What does this PR do?

Type of change: new feature

Overview:
This PR adds 2 features:

  1. Integrate parallel draft with autoregression in EAGLE training and inference. When parallel_draft_step > 1, multiple tokens are now generated in each step during training and inference, and all parallel draft tokens are used as context for the next ttt_step of training.
  2. Use the KV cache in EAGLE training. The attention mask is no longer a seq_len × seq_len square mask but a q_len × k_len rectangular mask, which reduces memory consumption and avoids redundant computation (see the shape sketch below).
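
As a rough illustration of the mask-shape change (illustrative shapes only, not the PR's actual code): with a KV cache, only the new tokens need query rows, so the mask becomes a q_len × k_len rectangle instead of a full square.

import torch

# Made-up sizes for illustration; `s` and `cached` are not values from the PR.
s = 4                  # new tokens processed in this EAGLE call (q_len)
cached = 8             # keys already stored in the KV cache
q_len, k_len = s, cached + s

# Old approach: a full (cached + s) x (cached + s) square causal mask.
square_mask = torch.ones(cached + s, cached + s).triu(diagonal=1).bool()

# New approach: only the last q_len query rows are needed, giving a q_len x k_len rectangle.
rect_mask = square_mask[-q_len:, :]
assert rect_mask.shape == (q_len, k_len)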

Usage

No API change.

Testing

Tested offline with Qwen3-30B and Llama3.2-1B. Passed the nmm-sandbox regression test.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Inference context can be passed through speculative decoding and top-level generation flows; public entry points accept an optional inference context.
    • Multi-draft generation now supports per-draft/per-step progression with per-step KV/cache updates, cumulative loss, and per-draft accuracy tracking.
  • Refactor

    • Replaced fixed multi-step assembly with loop-based per-step construction for inputs, attention masks, and sequence-parallel handling to simplify draft propagation.
  • Chores

    • Cross-rank consistency checks now warn by default; strict failure is opt-in.

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner September 29, 2025 16:51

coderabbitai bot commented Sep 29, 2025

Walkthrough

Threads a StaticInferenceContext through Eagle and top-level forward paths; replaces fixed multi-block branching with loop-driven per-draft input, attention-mask and kv-cache handling; generalizes sequence-parallel shapes and per-step loss/Top‑1 accumulation; changes rank-consistency check default to warn (no longer failing by default).

Changes

Cohort / File(s) Summary
EAGLE inference context & per-draft refactor
modelopt/torch/speculative/plugins/megatron_eagle.py
Adds StaticInferenceContext import and threads inference_context through public forwards (EagleModule.forward, _eagle_forward, _get_eagle_module_inputs, top-level model forward). Replaces fixed multi-block mask/concatenation with loop-based per-draft assembly using ttt_step and parallel_draft_step; updates per-draft input construction (input_ids, hidden_states, rotary_pos_emb, attention_mask), sequence-parallel handling, kv-cache/sequence_len_offset updates, and per-step loss/Top‑1 accumulation.
Distributed consistency util default
modelopt/torch/speculative/utils.py
Changes the check_data_consistency_across_ranks(..., fail_when_mismatch=False) default from True to False; call sites drop the explicit fail_when_mismatch=False. On mismatch the function now warns and returns the golden set unless the caller sets fail_when_mismatch=True.
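
As a behavioral sketch of this new default (decision logic only; the real utility works over torch.distributed broadcasts and is not reproduced here):

import warnings

def check_consistency_sketch(local_data, golden_data, fail_when_mismatch: bool = False):
    """Toy stand-in for check_data_consistency_across_ranks' mismatch handling."""
    if local_data != golden_data:
        msg = "Rank data differs from rank 0's golden set; adopting rank 0's data."
        if fail_when_mismatch:
            raise ValueError(msg)
        warnings.warn(msg)
    # Every rank proceeds with rank 0's data.
    return golden_data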

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant Model
  participant Eagle as EagleModule
  participant Out as OutputLayer
  participant Ctx as StaticInferenceContext

  Caller->>Model: forward(..., inference_context=Ctx)
  activate Model
  Note over Model: loop per ttt_step / parallel_draft_step\nassemble inputs and masks
  loop per draft step
    Model->>Eagle: forward(inputs..., inference_context=Ctx)
    activate Eagle
    Eagle-->>Model: hidden_states, pre_logits
    deactivate Eagle
    Model->>Out: compute_logits(hidden_states)
    Out-->>Model: logits_draft
    Model->>Ctx: update sequence_len_offset / kv-cache
    Model->>Model: accumulate loss / Top‑1 per draft
  end
  Model-->>Caller: outputs (logits, metrics)
  deactivate Model
sequenceDiagram
  autonumber
  actor Caller
  participant PSG as pseudo_speculative_generate
  participant Eagle as EagleModule
  participant Ctx as StaticInferenceContext

  Caller->>PSG: run(..., inference_context=Ctx)
  activate PSG
  loop for each draft step
    PSG->>Eagle: forward(draft_input, inference_context=Ctx)
    Eagle-->>PSG: draft_logits, hidden_states
    PSG->>PSG: select draft_token
    PSG->>Ctx: update sequence_len_offset / kv-cache note
    PSG->>PSG: gather/concat hidden_states across drafts
  end
  PSG-->>Caller: drafted tokens and states
  deactivate PSG

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

A rabbit taps keys with speculative cheer,
Loops bloom like clover, drafts hopping near.
Context in paw, offsets all aligned,
Masks weave snug tunnels, caches well‑timed.
Ranks murmur warnings—gentle, not severe. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 77.78%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title concisely identifies the two main features implemented in the PR (parallel draft with auto-regression and KV cache support in EAGLE training) and aligns with the stated objectives without extraneous detail.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yeyu/kv_cache



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

762-835: Critical: define mask_token_* attributes and revise input masking

  • No mask_token_{i} attributes found on self, causing an AttributeError. Add mask_token_0, mask_token_1, … in __init__ (or pull mask IDs from config) so getattr(self, f"mask_token_{…}") succeeds.
  • The current logic for parallel_draft_step > 1 replaces all positions with the mask token, overwriting base tokens. Instead, clone padded_input_ids and mask only the new draft slot.
  • (Optional) Shift position_ids when you left-shift input_ids to keep positional embeddings aligned.
  • (Optional) For ttt_step == 1, slice gathered_hidden_states[-s:] to match feature length and avoid mismatched sequence lengths.
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ line 780
-        eagle_inputs["input_ids"] = (
-            padded_input_ids
-            if parallel_draft_step == 1
-            else torch.full(
-                padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
-                device=padded_input_ids.device,
-                dtype=padded_input_ids.dtype,
-            )
-        )
+        if parallel_draft_step == 1:
+            eagle_inputs["input_ids"] = padded_input_ids
+        else:
+            eagle_inputs["input_ids"] = padded_input_ids.clone()
+            mask_tok = getattr(self, f"mask_token_{parallel_draft_step - 2}")
+            eagle_inputs["input_ids"][:, -1] = mask_tok
🧹 Nitpick comments (8)
modelopt/torch/speculative/utils.py (1)

295-312: Update docstring to reflect warning behavior

The docstring still claims we throw on divergence, but the new default only emits a warning unless fail_when_mismatch=True. Please align the comment with the current behavior so downstream users don’t expect an exception by default.

-        Use rank 0 data as the golden set to broadcast to all ranks.
-        Each rank will then compare to this data and through error if different.
+        Use rank 0 data as the golden set to broadcast to all ranks.
+        Each rank compares its data against this golden set and either raises
+        (when fail_when_mismatch=True) or emits a warning while forcing every
+        rank to adopt rank 0's data.
modelopt/torch/speculative/plugins/megatron_eagle.py (7)

198-253: Multi-step mask: ensure rectangular q×k compatibility and avoid built-in shadowing.

  • iter shadows Python’s built-in; rename to step_idx.
  • When called repeatedly (ttt_step > 1), s is captured once from the original k_len; OK for fixed-width appends, but brittle if upstream passes a pre-rectangular mask. Consider deriving q_len, k_len each loop and asserting q_len stays constant.
  • Minor perf: replace the per-position loop setting mask_1[:, :, i, i] with vectorized indexing.

Apply:

-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
+    s = attn_mask.shape[-1]  # base stride to append each round
+    q = attn_mask.shape[-2]
+    for step_idx in range(2, step + 1):
         # iter starts from 2nd step
-        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
+        mask_0 = attn_mask.clone()
+        mask_0[:, :, step_idx - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], q, s, dtype=torch.bool)
+        diag_idx = torch.arange(step_idx - 1, min(q, s) - 1, device=attn_mask.device)
+        mask_1[:, :, diag_idx, diag_idx] = False
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)

512-556: Threading inference_context into EagleModule is fine; verify downstream acceptance.

Passing inference_context to self.decoder aligns with KV-cache intent. Ensure the underlying TransformerBlock.forward supports it to avoid unexpected kwarg errors at runtime. Also, decoder_input_list is created but unused beyond [0]; remove the list.

Use the previous script to check TransformerBlock.forward signature acceptance of inference_context.


826-835: RoPE replication can be large; prefer computing once inside EagleModule.

Repeating rotary_pos_emb via cat for each (ttt_step, parallel_draft_step) increases memory. Since EagleModule can compute RoPE when None, consider passing None here (like in pseudo_speculative_generate) and letting the module compute the correct length based on attention_mask.


1014-1018: StaticInferenceContext capacity heuristic.

The 4× multiplier assumes up to 4 ttt_steps. If this becomes configurable, derive max_tokens = input_ids.size(1) * parallel_draft_step * max_ttt_step from config/kwargs rather than hard-coding 4.


1049-1071: Loss/acc blocks are duplicated; collapse into a single loop.

Four near-identical sections differ only by ttt_step index, features input, and decay power. Refactor into a for t in range(1, max_ttt_step+1) that:

  • builds eagle_inputs_t with the previous pre-norm features
  • runs _eagle_forward
  • slices/accumulates loss with decay**t
  • optionally reports accuracy

This reduces maintenance and indexing mistakes.

Also applies to: 1119-1142, 1170-1192, 1220-1242
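
A standalone sketch of the suggested per-round loop; forward_step and loss_step are stand-ins for the per-round EAGLE forward and loss slicing, not functions that exist in megatron_eagle.py:

from typing import Callable, Optional, Tuple

import torch

def accumulate_ttt_loss(
    forward_step: Callable[[int, Optional[torch.Tensor]], Tuple[torch.Tensor, torch.Tensor]],
    loss_step: Callable[[int, torch.Tensor], torch.Tensor],
    max_ttt_step: int,
    decay: float,
) -> torch.Tensor:
    """Run each ttt round once, feeding pre-norm features forward and decaying the loss."""
    total = torch.zeros(())
    features: Optional[torch.Tensor] = None
    for t in range(1, max_ttt_step + 1):
        logits_t, features = forward_step(t, features)        # EAGLE forward for round t
        total = total + (decay**t) * loss_step(t, logits_t)   # decay**t weighting per round
    return total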


1449-1451: KV cache note in pseudo-speculate: clarify behavior.

Comment states kv cache unsupported when sequence parallel may be enabled. Consider a runtime check raising a clear error if inference_context is mistakenly passed here, or document the precise conditions (sequence_parallel and/or EP).
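
A possible guard along those lines (a sketch only; config.sequence_parallel is an assumed attribute name, and the surrounding variables come from the method's context):

# Illustrative guard for pseudo_speculative_generate; not the actual implementation.
if inference_context is not None and getattr(self.config, "sequence_parallel", False):
    raise ValueError(
        "KV cache (inference_context) is not supported in pseudo_speculative_generate "
        "when sequence parallelism is enabled."
    )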


142-145: Docstring vs comparator mismatch.

mcore_version_higher_than doc says “at least this version” but uses strict >. Either change the docstring to “greater than” or switch to >= to avoid confusion.

-def mcore_version_higher_than(target_version: str):
-    """Check if megatron-core is least this version."""
-    return Version(megatron.core.__version__) > Version(target_version)
+def mcore_version_higher_than(target_version: str):
+    """Check if megatron-core is greater than this version."""
+    return Version(megatron.core.__version__) > Version(target_version)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 615f3c0 and b0edb38.

📒 Files selected for processing (2)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (19 hunks)
  • modelopt/torch/speculative/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/speculative/utils.py (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • pseudo_speculative_generate (1442-1560)
modelopt/torch/speculative/plugins/transformers.py (1)
  • pseudo_speculative_generate (1064-1141)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (469-516)
  • _eagle_forward (727-750)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)

933-960: sequence_len_offset update basis.

You increment inference_context.sequence_len_offset by eagle_inputs["input_ids"].shape[1] (s) each call. Given set_multi_step_attention_mask appends s keys per call, this matches k-len growth. If q_len changes or future refactors alter append width, this will skew. Safer: derive delta from attention_mask.size(-1) change or keep a local appended_k counter.
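
A sketch of the "derive the delta from the mask" option (names mirror the surrounding code; it assumes sequence_len_offset equals the number of keys already cached):

# Advance the offset by the observed growth of the key dimension instead of
# assuming it always equals input_ids.shape[1].
if inference_context is not None:
    new_keys = eagle_inputs["attention_mask"].shape[-1] - inference_context.sequence_len_offset
    inference_context.sequence_len_offset += new_keys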


1102-1111: Top-1 mapping and slicing: verify offsets and bounds.

  • gathered_logits = gathered_logits[i:-1] (and i+1:-1, etc.) assumes enough length; add asserts or clamp to avoid negative-length slices for short sequences.
  • For reduced vocab, eagle_top1 += d2t[eagle_top1] relies on d2t storing offsets; confirm invariant (offset mapping) holds across all training configs.

Also applies to: 1149-1161, 1199-1211, 1249-1261


1499-1504: Parallel draft token scaffolding correctness.

You append parallel_draft_step-1 mask tokens before each EAGLE call and later append the last parallel_draft_step hidden states. Validate that (i) mask tokens align with the same positions EAGLE writes logits for, and (ii) hidden_states growth matches logits selection using [-self.eagle_config.parallel_draft_step:]. Add asserts to catch misalignment.

Also applies to: 1551-1556


29-29: Guard or pin megatron-core version for StaticInferenceContext
Wrap the import of StaticInferenceContext in a try/except (or enforce a minimum megatron-core dependency) and only pass inference_context when TransformerBlock.forward supports it—manually verify your target megatron-core version provides StaticInferenceContext and accepts the inference_context parameter.


codecov bot commented Sep 29, 2025

Codecov Report

❌ Patch coverage is 33.33333% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.79%. Comparing base (cb44c55) to head (aa328ed).
⚠️ Report is 2 commits behind head on main.

Files with missing lines: modelopt/torch/speculative/utils.py (patch coverage 33.33%, 2 lines missing ⚠️)
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #391   +/-   ##
=======================================
  Coverage   73.79%   73.79%           
=======================================
  Files         171      171           
  Lines       17591    17591           
=======================================
  Hits        12982    12982           
  Misses       4609     4609           

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

198-252: Refactor set_multi_step_attention_mask to vectorized rectangular mask
Current iterative mask construction both overwrites an entire query row each loop and omits the final diagonal unmask, breaking causal semantics and leaving the last token unable to attend to new keys. Replace the loop with a single allocation of shape [B, H, S, S*step], copy the original causal mask into block 0, then unmask only the diagonal entries in each extra block:

 def set_multi_step_attention_mask(attn_mask, step):
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
-        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
-
-        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
-
-    return attn_mask
+    b, h, s, _ = attn_mask.shape
+    if step <= 1:
+        return attn_mask
+    new_mask = attn_mask.new_ones(b, h, s, s * step).bool()
+    # block-0: original causal mask
+    new_mask[:, :, :, :s] = attn_mask
+    # extra blocks: only diagonal unmasked
+    idx = torch.arange(s, device=attn_mask.device)
+    for j in range(1, step):
+        new_mask[:, :, idx, j * s + idx] = False
+    return new_mask.contiguous()

Add unit tests for s=4 and step=1..4 to verify both the base causal block and each diagonal-only extension.
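
A pytest sketch along those lines, testing an inline copy of the vectorized variant proposed in the diff above (mask convention: True means masked). The copy is included so the test is self-contained; point it at the real helper once the refactor lands.

import pytest
import torch

def set_multi_step_attention_mask_vectorized(attn_mask, step):
    """Inline copy of the vectorized variant suggested above, for illustration only."""
    b, h, s, _ = attn_mask.shape
    if step <= 1:
        return attn_mask
    new_mask = attn_mask.new_ones(b, h, s, s * step).bool()
    new_mask[:, :, :, :s] = attn_mask                 # block 0: original causal mask
    idx = torch.arange(s, device=attn_mask.device)
    for j in range(1, step):
        new_mask[:, :, idx, j * s + idx] = False      # extra blocks: diagonal only
    return new_mask.contiguous()

@pytest.mark.parametrize("step", [1, 2, 3, 4])
def test_multi_step_mask_shape_and_diagonals(step):
    s = 4
    base = torch.ones(s, s).triu(diagonal=1).bool().expand(1, 1, s, s)  # True = masked
    out = set_multi_step_attention_mask_vectorized(base, step)
    assert out.shape == (1, 1, s, s * step)
    # Block 0 must reproduce the original causal mask.
    assert torch.equal(out[:, :, :, :s], base)
    # Each extra block should unmask only its diagonal.
    for j in range(1, step):
        block = out[:, :, :, j * s : (j + 1) * s]
        assert torch.equal(~block, torch.eye(s, dtype=torch.bool).expand(1, 1, s, s))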

🧹 Nitpick comments (6)
modelopt/torch/speculative/plugins/megatron_eagle.py (6)

1014-1018: Avoid magic constant ‘4’ in StaticInferenceContext capacity

StaticInferenceContext(bs, seq_len * parallel_draft_step * 4) bakes in 4 rounds. Prefer deriving from actual number of EAGLE rounds executed (currently 4) so changes don’t silently under/over‑allocate.

Example:

-        eagle_inference_context = StaticInferenceContext(
-            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
-        )
+        num_rounds = 4  # current training rounds; keep in one place
+        max_len = input_ids.shape[1] * self.eagle_config.parallel_draft_step * num_rounds
+        eagle_inference_context = StaticInferenceContext(input_ids.shape[0], max_len)

Even better: centralize num_rounds as a config or class attribute.


1049-1071: Round‑1 logits/pre‑norm handling: collect features or document intent

eagle_hidden_states_0_pre_norm is overwritten each iteration; only the last parallel slice feeds round‑2 via features. If intentional, please comment; if not, collect all pre‑norms and pass the concatenation.

Potential fix:

-            eagle_logits_0 = []
+            eagle_logits_0, pre_norm_list = [], []
@@
-                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
+                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
@@
-                eagle_logits_0.append(eagle_logits_)
+                eagle_logits_0.append(eagle_logits_)
+                pre_norm_list.append(eagle_hidden_states_0_pre_norm)

Then use features=torch.cat(pre_norm_list, dim=0) (adjust downstream slicing accordingly).


1120-1161: Rounds 2 logic mirrors Round 1: consider DRY loop + recheck feature feedthrough

These blocks duplicate Round‑1 with different offsets. Consider a single loop over ttt_step in [1..num_rounds] that:

  • builds inputs,
  • forwards EAGLE with shared eagle_inference_context,
  • accumulates logits/loss/acc with computed offsets.

This reduces risk of indexing drift and concentrates the offset math in one place.


1170-1211: Same as above for Round 3

Consolidate into a generic per‑round loop; reduces maintenance surface and makes mask/offset reasoning easier.


1220-1261: Same as above for Round 4; also check acc slice windows

Unify into loop; confirm slices i+3:-1 vs labels [:, i+4:] as intended.


1551-1556: Hidden state update matches draft expansion

Concatenating the last parallel_draft_step pre‑norm states aligns with the token appends. Add a brief comment to document this coupling.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b0edb38 and 4061627.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (493-540)
  • _eagle_forward (624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)

29-29: LGTM: inference context import

Importing StaticInferenceContext is appropriate for KV‑cache support here.


512-556: Threading inference_context into EagleModule forward: good

Passing inference_context to the decoder is required for KV‑cache. Ensure call sites always pass the same context instance for all per‑round EAGLE calls within a forward.


933-961: KV‑cache offset update: OK; ensure per‑forward context reuse only

Incrementing sequence_len_offset by input_ids.shape[1] per call matches the rectangular mask growth. Just make sure a fresh context is created per model forward and reused across all EAGLE calls within that forward (not shared across microbatches).


1093-1098: Indexing of shifted loss windows: please verify alignment

loss_ = loss_[:, i:] and accumulation into loss[:, i+1:] assume a specific alignment between base logits and each parallel draft i. Add a brief comment and a unit check for off‑by‑one around sequence boundaries.


1102-1111: Accuracy slice may drop/shift one token

gathered_logits = gathered_logits[i:-1] discards the last step. Confirm this matches label slicing labels[:, i+1:]. A small test on a toy batch would help avoid off‑by‑one.


1449-1450: Docstring clarity: good

Explicitly stating no KV‑cache here avoids confusion with the training path.


1537-1543: Top‑k draft slice looks correct

Slicing the last parallel_draft_step steps to form draft_token is consistent with the appended placeholders.


1499-1504: mask_token_* buffers confirmed; no changes needed

The mask_token_{i} buffers are registered via self.register_buffer in modelopt/torch/speculative/eagle/eagle_model.py, ensuring they exist on the correct dtype and device at runtime.


755-835: Verify mask_token accessibility and RoPE generation

  • Ensure the mask_token_{i} buffers registered on EagleModel are actually available on self in the plugin (or reference them via self.eagle_module) – e.g. add
    if parallel_draft_step > 1:
        assert hasattr(self, f"mask_token_{parallel_draft_step - 2}")
  • Delegate rotary embedding generation to EagleModule by passing rotary_pos_emb = None here and allowing its internal use of inference_context to handle offsets.
  • Confirm that set_multi_step_attention_mask(attn_mask, step) produces the expected rectangular mask shape.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

762-795: Preserve history when masking draft tokens

  • In _get_eagle_module_inputs, don’t replace the full input_ids when parallel_draft_step > 1; clone padded_input_ids and overwrite only the rightmost parallel_draft_step – 1 positions with the appropriate mask_token_*, matching the behavior in pseudo_speculative_generate.
  • No mask_token_{i} attributes are defined on self, which will raise AttributeError; add those buffers (e.g. in __init__) or guard against their absence.
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)

933-961: KV-cache offset management: OK but side-effect needs guard on overflow.

You bump sequence_len_offset by input_ids.shape[1] each call. Add a max-cap guard to prevent overflow beyond StaticInferenceContext capacity.

-        if inference_context is not None:
-            inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+        if inference_context is not None:
+            advance = eagle_inputs["input_ids"].shape[1]
+            inference_context.sequence_len_offset = min(
+                inference_context.sequence_len_offset + advance,
+                inference_context.max_seq_len,
+            )

What is the exact field name for capacity in StaticInferenceContext? Adjust accordingly.


1014-1018: Hard-coded “×4” capacity couples KV-cache to 4 rounds; derive from schedule.

If ttt rounds change, this under/over-allocates. Compute from the actual training schedule:

-        eagle_inference_context = StaticInferenceContext(
-            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
-        )
+        max_rounds = 4  # TODO: derive from a config/arg if you plan to vary rounds
+        max_len = input_ids.shape[1] * self.eagle_config.parallel_draft_step * max_rounds
+        eagle_inference_context = StaticInferenceContext(input_ids.shape[0], max_len)

If you plan to auto-regress >4, promote max_rounds to a config knob.


1499-1504: Mask token attributes must exist and be on the right device/dtype.

Accessing getattr(self, f"mask_token_{i}") without validation can raise. Add a small helper or fallback, and ensure to .to(eagle_ids.device, dtype=eagle_ids.dtype) before assignment.

Also applies to: 1537-1543


1551-1555: Hidden-state growth: ensure KV/activation memory is bounded.

You append next_hidden_states_input slices each step; for large steps this grows O(steps). If steps is bounded small, OK. Otherwise consider reusing a rolling buffer.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b0edb38 and 4061627.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (493-540)
  • _eagle_forward (624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: example-tests-pr (llm_ptq)
  • GitHub Check: gpu-tests-pr
  • GitHub Check: multi-torch (27)
  • GitHub Check: multi-torch (26)
  • GitHub Check: multi-transformers (min)
  • GitHub Check: windows
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)

29-29: Good addition: inference context import.

Importing StaticInferenceContext here makes sense for threading KV-cache state.


512-512: Threading inference_context is correct.

Forwarding the context into decoder is consistent with KV-cache usage.

Confirm that TransformerBlock in your mcore version accepts inference_context; older versions may ignore or raise.

Also applies to: 553-554


826-835: Mask/pos-emb replication factor looks right; verify rotary length matches key length.

You replicate rotary_pos_emb by the effective step count. Ensure decoder expects k_len rotary embeddings equal to attention_mask K dimension.


1093-1098: Loss windowing looks consistent; verify off-by-one with new concatenated features.

Given the concatenation change above, double-check the i-based slicing still aligns with label offsets.

Also applies to: 1143-1148, 1193-1198, 1243-1248

Comment on lines 242 to 252
    for iter in range(2, step + 1):
        # iter starts from 2nd step
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, iter - 2, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(iter - 1, s - 1):
            mask_1[:, :, i, i] = False

        attn_mask = torch.cat((mask_0, mask_1), dim=-1)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Multi-step mask: prevent OOB and vectorize; current loop risks indexing past q_len and is slow.

  • iter-2 row index will go out of bounds when step > s.
  • The inner for-loop over i is O(s) per step; can be vectorized.

Apply bounds + vectorization:

-    for iter in range(2, step + 1):
+    s = attn_mask.shape[-1]
+    # Cap steps to avoid indexing beyond query length.
+    for iter in range(2, min(step, s) + 1):
         # iter starts from 2nd step
         mask_0 = attn_mask.clone().detach()
         mask_0[:, :, iter - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s, dtype=torch.bool)
+        idx = torch.arange(iter - 1, s - 1, device=attn_mask.device)
+        mask_1[:, :, idx, idx] = False

Please add a small unit test to exercise step values {1, 2, s, s+3}.


🏁 Script executed:

#!/bin/bash
sed -n '230,270p' modelopt/torch/speculative/plugins/megatron_eagle.py

Length of output: 2355


🏁 Script executed:

#!/bin/bash
sed -n '150,300p' modelopt/torch/speculative/plugins/megatron_eagle.py

Length of output: 9208


Cap multi-step range to sequence length and vectorize mask construction

  • Prevent OOB at mask_0[:, :, iter-2, :] by capping the loop:
    # s = attn_mask.shape[-1]  # already defined above
    for iter in range(2, min(step, s) + 1):
        …
  • Replace the per-element diagonal loop with a batched index assignment:
    mask_1 = attn_mask.new_ones(batch, heads, s, s, dtype=torch.bool)
    idx = torch.arange(iter - 1, s - 1, device=attn_mask.device)
    mask_1[:, :, idx, idx] = False
  • Add unit tests covering step = 1, 2, s, s+3 to verify no OOB and correct masking.
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 242 to 252,
the for-loop can go out-of-bounds when step > sequence length and the mask_1
diagonal is built with a Python loop; change the loop range to cap at the
sequence length (use range(2, min(step, s) + 1)) and replace the per-element
diagonal loop with a vectorized batched index assignment using torch.arange on
the correct device/dtype to set mask_1[:, :, idx, idx] = False, ensure mask_1 is
created with the same device and boolean dtype as attn_mask, and add unit tests
for step values 1, 2, s, and s+3 to assert no OOB and correct mask
shapes/values.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

243-249: Critical: Out-of-bounds access when step > sequence_length.

When step > s (where s = attn_mask.shape[-1]), line 245 will attempt to index mask_0[:, :, step_idx, :] with step_idx >= s, causing an out-of-bounds error. Additionally, the inner loop at lines 248-249 is O(s) per iteration and should be vectorized.

Apply this fix to cap the loop range and vectorize the diagonal assignment:

 s = attn_mask.shape[-1]
-for step_idx in range(step):
+# Cap step_idx to avoid indexing beyond query length
+for step_idx in range(min(step, s)):
     mask_0 = attn_mask.clone().detach()
     mask_0[:, :, step_idx, :] = True
     mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(step_idx + 1, s - 1):
-        mask_1[:, :, i, i] = False
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s, dtype=torch.bool)
+    # Vectorize diagonal assignment
+    if step_idx + 1 < s - 1:
+        idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device)
+        mask_1[:, :, idx, idx] = False
 
     attn_mask = torch.cat((mask_0, mask_1), dim=-1)

Add unit tests covering step values of 1, 2, s, and s+3 to verify correctness.

Based on the past review comment, this issue was previously identified but not yet resolved.

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

826-834: Rename parallel_draft_step parameter for clarity

The step calculation (ttt_step * eagle_config.parallel_draft_step + draft_index) is correct; rename the parallel_draft_step parameter (and its occurrences and docstring) to draft_index to distinguish it from eagle_config.parallel_draft_step.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3c06000 and 22441ca.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (493-540)
  • _eagle_forward (624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/speculative/plugins/megatron_eagle.py (9)

29-29: LGTM: Import addition supports KV cache feature.

The StaticInferenceContext import is necessary for the new KV cache functionality described in the PR objectives.


143-143: LGTM: Improved docstring accuracy.

The docstring now correctly describes the strict inequality check (>).


512-512: LGTM: Signature addition supports KV cache propagation.

The inference_context parameter enables KV cache management throughout the EAGLE module forward pass, aligning with the PR objectives.


796-805: Correct: Feature slicing preserves last s tokens only.

The slicing feature = gathered_features[-s:] is correct, as previous parallel draft tokens are preserved in the KV cache (as confirmed in past review discussion). The KV cache mechanism maintains the full context across iterations.

Based on learnings from prior review.


947-949: LGTM: KV cache offset correctly updated after each EAGLE call.

The sequence_len_offset increment ensures the KV cache tracks the current position across multiple EAGLE module invocations, which is essential for parallel draft functionality.


1107-1129: LGTM: Per-draft accuracy calculation correctly indexes logits and labels.

The accuracy computation maintains consistent next-token prediction offsets across drafts:

  • Logits sliced from [i + ttt_step : -1]
  • Labels sliced from [:, i + ttt_step + 1 :]

The 1-position shift correctly aligns predictions with target tokens.
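
A toy check of that alignment using index arithmetic only (arbitrary example values, not the real logits or labels tensors):

import torch

seq_len, i, ttt_step = 8, 1, 2
positions = torch.arange(seq_len)

logit_positions = positions[i + ttt_step : -1]    # positions whose logits predict the next token
label_positions = positions[i + ttt_step + 1 :]   # the corresponding target positions

assert len(logit_positions) == len(label_positions)
assert torch.equal(logit_positions + 1, label_positions)  # exactly one position ahead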


1361-1365: LGTM: Parallel draft token generation correctly appends mask tokens.

The loop appends parallel_draft_step - 1 mask tokens and corresponding hidden states, which together with the one real token forms a complete set of parallel_draft_step positions for EAGLE to predict.


1399-1418: LGTM: Draft token extraction correctly handles parallel drafts.

The slicing [-self.eagle_config.parallel_draft_step :] extracts all parallel draft tokens and their corresponding hidden states, maintaining consistency with the parallel draft architecture.


1099-1105: I’ve extracted the EAGLE loss helper to inspect its behavior. Let me know the output so I can verify the accumulation logic.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

242-253: Fix OOB + in-place aliasing in multi-step mask; vectorize diagonal writes.

  • Range loop can index past q_len when step > s; cap to s.
  • In-place shift uses overlapping slices; RHS must be cloned to avoid undefined results.
  • Build mask_1 with vectorized diagonal assignment; avoid Python loop.

Apply:

-    s = attn_mask.shape[-1]
-    for step_idx in range(step):
-        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, step_idx, :] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx + 1, s - 1):
-            mask_1[:, :, i, i] = False
-        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
+    s = attn_mask.shape[-2]  # query length
+    iter_max = min(step, s)
+    for step_idx in range(iter_max):
+        mask_0 = attn_mask.clone()
+        mask_0[:, :, step_idx, :] = True
+        # Avoid aliasing on overlapping slices
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone()
+        # Vectorized diagonal
+        mask_1 = torch.ones(
+            (*attn_mask.shape[:2], s, s), dtype=torch.bool, device=attn_mask.device
+        )
+        if step_idx + 1 < s - 1:
+            idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
+        attn_mask = torch.cat((mask_0, mask_1), dim=-1)

Also consider preallocating mask_1 outside the loop and only rewriting the diagonal slice per iteration.

🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)

830-835: Add a shape assertion to ensure rotary_pos_emb matches expanded k_len.

Guard against subtle mismatches after repetition.

 e.g.,
+assert (
+    eagle_inputs["rotary_pos_emb"].shape[0] == eagle_inputs["attention_mask"].shape[-1]
+), "rotary_pos_emb length must equal attention_mask k_len"

1048-1066: Prefer padding labels with ignore_index instead of 0.

Zero may be a valid token id; using it biases loss. Pad with the model’s ignore index (commonly -100).

-                right_token_pad = torch.zeros(
-                    (labels.shape[0], 1),
-                    dtype=labels.dtype,
-                    device=labels.device,
-                )
+                ignore_index = getattr(self, "label_ignore", -100)
+                right_token_pad = torch.full(
+                    (labels.shape[0], 1), fill_value=ignore_index, dtype=labels.dtype, device=labels.device
+                )
                 labels = torch.cat((labels, right_token_pad), dim=-1)

Please confirm compute_language_model_loss uses ignore_index consistently.


947-950: Sequence length accounting in KV cache: add a safety check.

Incrementing by input_ids.shape[1] per call assumes exactly s new KV entries each call. Add a guard to avoid exceeding the preallocated max.

 if inference_context is not None:
-    inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+    inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+    assert (
+        inference_context.sequence_len_offset
+        <= inference_context.max_sequence_length
+    ), "StaticInferenceContext capacity exceeded"

Also applies to: 1015-1019

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 22441ca and b88b27a.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (493-540)
  • _eagle_forward (624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)

796-806: Feature slicing logic is correct with KV cache; keep last s only.

Using only the last s features for the next ttt_step is expected because previous drafts are preserved in the KV cache; the slice gathered_features[-s:] is correct.

Based on learnings

Also applies to: 806-819, 821-825


1067-1070: Pre‑norm overwrite per draft is intentional; no change needed.

Overwriting eagle_hidden_states_*_pre_norm inside the per‑draft loop is expected for parallel draft.

Based on learnings

Also applies to: 1092-1092


1311-1312: KV cache deliberately disabled in pseudo_speculative_generate; good.

Not passing inference_context here avoids cache/sequence‑parallel interplay; the guard comment is clear.

Also applies to: 1387-1390


29-29: Guard StaticInferenceContext import and conditional inference_context usage for backward compatibility. Verify the minimum megatron-core version that introduces StaticInferenceContext and adds the inference_context parameter to TransformerBlock.forward; if you must support earlier versions, apply the import guard and conditional kwarg wiring as shown above in all relevant blocks (lines 512–515, 553, 933–944, 1015–1019).


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

242-252: Prevent potential OOB and consider vectorization.

The loop iterates for step_idx in range(step) and indexes mask_0[:, :, step_idx, :] at line 245. If step exceeds s (the sequence length), this will cause an out-of-bounds error. Additionally, the inner loop constructing the diagonal mask (lines 248-249) could be vectorized for better performance.

Apply this defensive cap and vectorization:

     s = attn_mask.shape[-1]
-    for step_idx in range(step):
+    # Cap step to prevent OOB indexing
+    for step_idx in range(min(step, s)):
         mask_0 = attn_mask.clone().detach()
         mask_0[:, :, step_idx, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx + 1, s - 1):
-            mask_1[:, :, i, i] = False
+        # Vectorize diagonal assignment
+        if step_idx + 1 < s - 1:
+            idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)

Note: This addresses the concern raised in previous reviews about OOB and performance.

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

142-144: Clarify version comparison semantics.

The function name mcore_version_higher_than and implementation using > indicate a strict comparison, but "greater than" can be ambiguous. Consider updating the docstring to explicitly state "strictly greater than" for clarity.

-    """Check if megatron-core is greater than this version."""
+    """Check if megatron-core version is strictly greater than the target version."""
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b88b27a and d6f1489.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (493-540)
  • _eagle_forward (624-649)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (8)
modelopt/torch/speculative/plugins/megatron_eagle.py (8)

947-950: LGTM: Proper KV cache offset tracking.

The sequence_len_offset update after each EAGLE module call correctly maintains the KV cache index for incremental generation. This aligns with the PR objective of adding KV cache support in EAGLE training.


796-825: LGTM: Proper sequence parallel handling.

The gather/scatter pattern correctly handles sequence parallelism, and the feature shifting logic for ttt_step > 0 properly implements the autoregressive pattern for parallel draft generation.


1107-1129: LGTM: Per-draft accuracy reporting properly implemented.

The accuracy computation correctly handles per-draft evaluation with proper vocab mapping and index alignment. The reporting provides useful per-draft metrics for monitoring training progress.


1360-1424: LGTM: Parallel draft generation properly implemented.

The restructured loop correctly handles parallel draft token generation:

  1. Prepares input by appending mask tokens for parallel drafts (lines 1361-1365)
  2. Selects multiple draft tokens from the output (lines 1399-1405)
  3. Properly cleans up temporary mask tokens before next iteration (lines 1411-1414)
  4. Maintains correct hidden state sequence (lines 1418-1423)

This aligns with the PR objective of supporting parallel draft with autoregression.


1048-1065: LGTM: Label padding handles offline training edge case.

The label padding logic correctly handles the case where offline training labels are one token shorter. The comment appropriately documents the small training error this introduces when logit_distillation=False. The conditional loss zeroing for frozen base model is also correct.


1015-1019: LGTM: KV cache properly sized for parallel drafts.

The StaticInferenceContext initialization correctly allocates cache capacity for all parallel draft steps and TTT steps. This enables the rectangular attention mask optimization mentioned in the PR objectives, reducing memory consumption and avoiding redundant computation.


785-794: No action needed on mask tokens. Initialization in eagle_model.py covers all mask_token_{i} used by the plugin.


1099-1105: Loss indexing and decay weighting align correctly with autoregressive offsets; no changes needed.

…ach ttt step, only the non_parallel tokens from previous ttt are used as context

Signed-off-by: Ye Yu <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

828-841: Potential None dereference in gathered_features

Line 838 accesses gathered_features[:-1, :, :] when ttt_step > 0, but this assumes gathered_features is not None. While the caller may ensure this, adding a defensive check would improve robustness.

Consider adding an assertion or guard:

         eagle_inputs["hidden_states"] = (
             gathered_hidden_states
             if ttt_step == 0
             else torch.cat(
                 (
                     torch.zeros(
                         (1, b, h),
                         dtype=hidden_states.dtype,
                         device=hidden_states.device,
                     ),
+                    # gathered_features should not be None when ttt_step > 0
                     gathered_features[:-1, :, :],  # type: ignore[index]
                 )
             )
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b47928e and 44e0eb1.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (16 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (501-548)
  • _eagle_forward (632-657)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/speculative/plugins/megatron_eagle.py (7)

29-29: LGTM: StaticInferenceContext import

The import is correctly added to support kv cache management in EAGLE training and inference.


143-143: LGTM: Documentation clarification

The docstring update clarifies the version comparison semantics.


948-982: LGTM: Proper inference_context handling

The changes correctly:

  1. Thread inference_context through the EAGLE forward pass
  2. Update sequence_len_offset after each EAGLE module call (lines 968-970) to maintain kv cache consistency
  3. Return all necessary hidden states for the next iteration

1090-1125: LGTM: Nested loop structure for parallel draft

The nested loop structure correctly implements the parallel draft with auto-regression:

  1. Outer loop (line 1090) iterates over ttt_steps
  2. Inner loop (line 1092) iterates over parallel_draft_step to generate multiple tokens per step
  3. Line 1112-1113 correctly preserves only the first draft's pre-norm states (per learnings)
  4. Lines 1119-1124 properly adjust sequence_len_offset to discard intermediate drafts' kv cache entries

Based on learnings


1131-1153: LGTM: Loss and accuracy computation

The loss computation and accuracy reporting correctly:

  1. Apply loss decay factor: eagle_loss_decay_factor ** (ttt_step + i) (line 1136)
  2. Handle loss offsets properly for each draft position (lines 1134-1137)
  3. Map draft tokens to target vocabulary when needed (line 1148)
  4. Compute top-1 accuracy for each parallel draft (lines 1139-1153)

1342-1343: LGTM: Clarified kv cache constraint

The updated docstring correctly documents that kv cache is not supported when sequence parallel is enabled.


1392-1454: LGTM: Parallel draft token generation

The implementation correctly handles parallel draft generation:

  1. Lines 1392-1396: Add parallel_draft_step - 1 mask tokens and duplicate hidden states
  2. Lines 1430-1436: Extract parallel_draft_step draft tokens from logits
  3. Lines 1442-1445: Clean up mask tokens before adding real draft tokens
  4. Lines 1449-1453: Concatenate appropriate hidden states for the next iteration

The logic is consistent and aligns with the parallel draft mechanism.

Comment on lines +267 to +276
    for step_idx in range(step):
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, step_idx, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(step_idx + 1, s - 1):
            mask_1[:, :, i, i] = False

        attn_mask = torch.cat((mask_0, mask_1), dim=-1)

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify bounds and consider vectorization

The loop construction may encounter issues:

  1. Line 272: when step_idx >= s - 1, range(step_idx + 1, s - 1) is empty; that is safe, but it may indicate unexpected behavior for large step values.
  2. The inner loop at lines 272-273 is O(s) per step and could be vectorized for efficiency.

Consider:

  • Adding validation that step is within expected bounds relative to sequence length
  • Vectorizing the diagonal assignment using torch.arange

Based on past review comments, this was previously flagged. Please verify the current implementation handles all edge cases correctly, especially when step > s.


🏁 Script executed:

#!/bin/bash
# Search for calls to set_multi_step_attention_mask to understand expected step values
rg -n -A 3 -B 3 'set_multi_step_attention_mask' --type=py

Length of output: 1221


🏁 Script executed:

#!/bin/bash
sed -n '190,260p' modelopt/torch/speculative/plugins/megatron_eagle.py
sed -n '260,340p' modelopt/torch/speculative/plugins/megatron_eagle.py

Length of output: 10130


Validate step bounds and vectorize diagonal mask

  • In set_multi_step_attention_mask (modelopt/torch/speculative/plugins/megatron_eagle.py), add at the top:
    s = attn_mask.shape[-1]
    if step > s:
        raise ValueError(f"step ({step}) must be ≤ sequence length ({s})")
  • Replace the inner for i in range(step_idx+1, s-1): mask_1[..., i, i] = False loop with a vectorized assignment, e.g.:
    idx = torch.arange(step_idx+1, s-1, device=attn_mask.device)
    mask_1[..., idx, idx] = False
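
For intuition on the resulting mask shape, here is a standalone trace of the concatenation pattern in the quoted loop (simplified to track shapes only, with made-up sizes): each iteration appends another s key columns, so after step iterations the mask is rectangular with query length s and key length (step + 1) * s.

    import torch

    b, n, s, step = 1, 1, 4, 2
    attn_mask = torch.zeros(b, n, s, s, dtype=torch.bool)
    for step_idx in range(step):
        mask_0 = attn_mask.clone()               # clone of the current (widening) mask
        mask_1 = attn_mask.new_ones(b, n, s, s)  # fresh s x s block
        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
    assert attn_mask.shape == (b, n, s, (step + 1) * s)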
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 267–276,
validate that step does not exceed the sequence length by reading s from
attn_mask.shape[-1] and raising a ValueError if step > s; then replace the
per-index Python loop that clears the diagonal in mask_1 with a vectorized
assignment using a torch.arange of indices on the same device as attn_mask to
set mask_1[..., idx, idx] = False (ensure mask_1 remains boolean and use
attn_mask.device for the arange).

Comment on lines +1036 to +1040
# EAGLE kv cache
eagle_inference_context = StaticInferenceContext(
    input_ids.shape[0],
    input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1),
)
@coderabbitai coderabbitai bot Oct 1, 2025

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify kv cache buffer size calculation

The buffer size is calculated as input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1). Let me verify this accounts for all tokens correctly:

  • Initial tokens: input_ids.shape[1]
  • Each ttt_step generates: parallel_draft_step tokens
  • But only 1 token per ttt_step is kept for the next iteration
  • Total additional tokens: approximately ttt_steps * parallel_draft_step

The formula (parallel_draft_step + ttt_steps - 1) seems like it might underestimate when both parameters are large.

Please verify that the buffer size correctly accounts for the maximum number of tokens that will be present in the kv cache at any point during the forward pass.


🏁 Script executed:

#!/bin/bash
# Search for StaticInferenceContext usage and buffer size patterns
rg -n -B 3 -A 3 'StaticInferenceContext' --type=py

Length of output: 4446


Fix kv-cache buffer size calculation
The size passed to StaticInferenceContext at modelopt/torch/speculative/plugins/megatron_eagle.py:1036–1040 is wrong—multiplying the original seq-len by (parallel_draft_step + ttt_steps - 1) miscomputes the total tokens processed. It should be based on addition, e.g.:

max_seq_len = input_ids.shape[1] + self.eagle_config.parallel_draft_step * ttt_steps

so the KV cache can hold the original tokens plus all generated ones.

🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 1036 to
1040, the KV-cache buffer size is computed incorrectly by multiplying the
original sequence length by (parallel_draft_step + ttt_steps - 1); change the
calculation to compute the maximum sequence length as the original input length
plus all generated tokens: max_seq_len = input_ids.shape[1] +
self.eagle_config.parallel_draft_step * ttt_steps, and pass that value as the
second argument to StaticInferenceContext so the KV cache can hold the original
tokens plus all generated tokens.

yeyu-nvidia (Contributor, Author)

No, the original value is correct

coderabbitai bot

@yeyu-nvidia, thank you for the clarification! I see now that the calculation is correct.

Looking at the code flow more carefully, I can see that:

  • Each ttt_step processes input_ids.shape[1] * parallel_draft_step tokens
  • After each iteration, sequence_len_offset is decremented by input_ids.shape[1] * (parallel_draft_step - 1) (lines 1067-1070)
  • This means only input_ids.shape[1] tokens worth of KV cache is retained per iteration
  • The total maximum KV cache size needed is therefore: input_ids.shape[1] * (ttt_steps + parallel_draft_step - 1)

The original formula correctly accounts for the fact that KV cache entries are being reused/discarded during the loop, not simply accumulated. My apologies for the confusion!
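
A quick numeric check of that reasoning, with made-up sizes rather than values from the PR:

    # seq_len stands for input_ids.shape[1]; the numbers are illustrative only
    seq_len, parallel_draft_step, ttt_steps = 128, 2, 3
    buffer = seq_len * (parallel_draft_step + ttt_steps - 1)  # 512
    # peak usage occurs inside the last ttt_step:
    retained = seq_len * (ttt_steps - 1)        # 256 kept from earlier iterations
    in_flight = seq_len * parallel_draft_step   # 256 written during this iteration
    assert retained + in_flight == buffer       # 512 == 512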


✏️ Learnings added
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.819Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

267-276: Validate step parameter and vectorize diagonal mask construction.

The loop construction may encounter issues:

  1. Out-of-bounds risk: When step > s (sequence length), line 269 mask_0[:, :, step_idx, :] will index beyond the tensor dimension.
  2. Performance: The inner loop at lines 272-273 is O(s) per step and could be vectorized.

Apply this fix to add bounds validation and vectorize:

 def set_multi_step_attention_mask(attn_mask, step):
+    s = attn_mask.shape[-1]
+    if step > s:
+        raise ValueError(f"step ({step}) must be ≤ sequence length ({s})")
+    
-    s = attn_mask.shape[-1]
     for step_idx in range(step):
         mask_0 = attn_mask.clone().detach()
         mask_0[:, :, step_idx, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx + 1, s - 1):
-            mask_1[:, :, i, i] = False
+        if step_idx + 1 < s - 1:
+            idx = torch.arange(step_idx + 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)

This resolves the concerns from the previous review and adds safety + performance improvements.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8212265 and aa328ed.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (17 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.
📚 Learning: 2025-10-01T21:34:30.854Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1036-1040
Timestamp: 2025-10-01T21:34:30.854Z
Learning: In EAGLE parallel draft with KV cache (modelopt/torch/speculative/plugins/megatron_eagle.py), the StaticInferenceContext buffer size is correctly calculated as `input_ids.shape[1] * (parallel_draft_step + ttt_steps - 1)` because each ttt_step iteration discards (parallel_draft_step - 1) worth of KV entries by decrementing sequence_len_offset, so only one token's worth of KV cache is retained per ttt_step.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
📚 Learning: 2025-09-29T17:40:37.310Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.310Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/megatron_eagle.py
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
modelopt/torch/speculative/plugins/transformers.py (2)
  • _get_eagle_module_inputs (501-548)
  • _eagle_forward (632-657)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (8)
modelopt/torch/speculative/plugins/megatron_eagle.py (8)

29-29: LGTM!

The StaticInferenceContext import is correctly added to support KV cache functionality in EAGLE training and inference.


143-143: LGTM!

The docstring clarification from "higher than" to "greater than" improves precision.


817-835: LGTM!

The Eagle module input construction correctly handles:

  • Conditional input_ids/embeddings based on parallel_draft_index
  • Hidden states selection for first vs. subsequent drafts
  • Multi-step attention mask generation
  • Rotary position embeddings concatenation for multiple steps

The logic aligns with the parallel draft architecture described in the PR objectives.


921-954: LGTM!

The _eagle_forward method correctly:

  • Threads inference_context through to the eagle_module
  • Updates sequence_len_offset after processing to maintain KV cache consistency
  • Returns the same values with improved formatting

The KV cache handling is sound and aligns with the learned context about buffer size calculations.

Based on learnings


1009-1013: LGTM!

The KV cache buffer size calculation input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1) is correct. Each ttt_step iteration discards (parallel_draft_step - 1) sequence lengths' worth of KV entries (lines 1115-1117), so only one sequence length's worth of KV cache is retained per ttt_step.

Based on learnings


1042-1155: LGTM!

The forward method training loop correctly implements:

  • Label padding for offline training compatibility (lines 1042-1053)
  • Nested loops for ttt_steps and parallel_draft_step (lines 1062-1087)
  • Per-draft Eagle forward passes with KV cache context (lines 1074-1081)
  • Preservation of first draft's pre-norm hidden states for next iteration (lines 1083-1084), which is the expected behavior per learnings
  • Sequence-parallel gathering/scattering (lines 1089-1110)
  • KV cache offset adjustment to discard unused entries (lines 1115-1117)
  • Loss accumulation with exponential decay (lines 1124-1130)
  • Top-1 accuracy reporting per draft (lines 1132-1146)

The implementation aligns with the PR objectives of parallel draft with auto-regression and efficient KV cache management.

Based on learnings


1385-1391: Verify mask_token_* buffer registration.

Lines 1387-1388 reference getattr(self, f"mask_token_{i}"), which requires the buffers flagged as missing in the earlier review comment (lines 806-815). Ensure those buffers are registered in EagleModule.__init__ so that this code can execute correctly.

Otherwise, the logic for adding mask tokens and padding hidden states for parallel draft generation is sound.


1405-1465: LGTM!

The pseudo_speculative_generate method correctly implements parallel draft generation:

  • Proper sequence-parallel handling for embeddings and hidden states (lines 1405-1417)
  • Correct replacement of dummy hidden states with embeddings for mask tokens (lines 1409-1415)
  • Appropriate extraction of draft tokens from multiple positions (lines 1436-1442)
  • Clean removal of mask tokens before concatenating new drafts (lines 1448-1461)

The implementation maintains consistency with the parallel draft architecture and handles tensor operations correctly.

Comment on lines +806 to +815
eagle_inputs["input_ids"] = (
padded_input_ids
if parallel_draft_index == 0
else torch.full(
padded_input_ids.shape,
getattr(self, f"mask_token_{parallel_draft_index - 1}"),
device=padded_input_ids.device,
dtype=padded_input_ids.dtype,
)
if self.config.sequence_parallel:
eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
eagle_inputs["hidden_states"]
)

eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3)
eagle_inputs["position_ids"] = torch.cat(
(position_ids, position_ids, position_ids), dim=-1
)

if rotary_pos_emb is not None:
eagle_inputs["rotary_pos_emb"] = torch.cat(
(rotary_pos_emb, rotary_pos_emb, rotary_pos_emb),
dim=0,
)
else:
# [TODO] (yeyu): there will be problem here with MLA
eagle_inputs["rotary_pos_emb"] = None
else:
eagle_inputs["input_ids"] = torch.cat(
(padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids),
dim=-1,
)

if self.config.sequence_parallel:
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
gathered_features = gather_from_sequence_parallel_region(features)
else:
gathered_hidden_states = hidden_states
gathered_features = features
eagle_inputs["hidden_states"] = torch.cat(
(
gathered_hidden_states,
torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device),
gathered_features[:-1, :, :],
),
dim=0,
)
if self.config.sequence_parallel:
eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
eagle_inputs["hidden_states"]
)

eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4)
eagle_inputs["position_ids"] = torch.cat(
(position_ids, position_ids, position_ids, position_ids), dim=-1
)

if rotary_pos_emb is not None:
eagle_inputs["rotary_pos_emb"] = torch.cat(
(rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb),
dim=0,
)
else:
# [TODO] (yeyu): there will be problem here with MLA
eagle_inputs["rotary_pos_emb"] = None
)

⚠️ Potential issue | 🔴 Critical

Register mask_token_* buffers in EagleModule.__init__.

Lines 811-814 reference self.mask_token_{parallel_draft_index - 1} attributes, but these buffers are never registered in the EagleModule.__init__ method (lines 369-482). This will cause an AttributeError at runtime when parallel_draft_index > 0.

Add buffer registration in EagleModule.__init__ (after line 482, before the closing of __init__):

        # Register mask token buffers for parallel draft
        if self.config.parallel_draft_step > 1:
            for i in range(self.config.parallel_draft_step - 1):
                # Use a sentinel value like vocab_size or a special token ID
                self.register_buffer(
                    f"mask_token_{i}", 
                    torch.tensor(self.config.vocab_size - 1 - i, dtype=torch.long)
                )

Note: Choose an appropriate mask token ID based on your tokenizer (e.g., a special unused token, or vocab_size - 1).
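
If the suggested registration is adopted, a small self-contained check along these lines (using a stand-in module, not the real EagleModule) shows why register_buffer is the right tool: the sentinel ids move with the module and are saved in checkpoints.

    import torch
    from torch import nn

    class TinyDraft(nn.Module):
        """Stand-in module for illustration only."""

        def __init__(self, parallel_draft_step: int, vocab_size: int):
            super().__init__()
            for i in range(parallel_draft_step - 1):
                self.register_buffer(
                    f"mask_token_{i}", torch.tensor(vocab_size - 1 - i, dtype=torch.long)
                )

    m = TinyDraft(parallel_draft_step=3, vocab_size=32000)
    assert m.mask_token_0.dtype == torch.long
    assert "mask_token_1" in m.state_dict()  # buffers are included in checkpoints by default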

🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 806-815,
self.mask_token_{i} attributes are referenced but never registered causing
AttributeError when parallel_draft_index > 0; fix by registering those
mask_token_* buffers in EagleModule.__init__ (after line 482) by looping for i
in range(self.config.parallel_draft_step - 1) and calling
self.register_buffer(f"mask_token_{i}", torch.tensor(<chosen_token_id>,
dtype=torch.long)) where <chosen_token_id> is an appropriate sentinel (e.g.
self.config.vocab_size - 1 - i or a dedicated unused token id /
self.config.mask_token_id if available) so the attributes exist as long tensors
on the module.
