
Conversation

h-guo18
Contributor

@h-guo18 h-guo18 commented Sep 21, 2025

What does this PR do?

Type of change: new feature

Overview:

This PR improves Eagle3 training efficiency with cross attention and flex attention. The main changes include:

  • Previously we concatenated the inputs and used self attention during TTT, causing redundant computation and memory allocation. This PR instead caches the eagle KV states during TTT and uses cross attention.
  • Previously attention was computed with torch.sdpa. This PR switches to torch.flex_attention to exploit its sparse computation and symbolic TTT masks. This gives an additional speedup but requires a fixed training sequence length.

See the effect of the two optimizations above in the plots below.
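As a concrete illustration of the second change, here is a minimal sketch of how the first TTT step's symbolic mask can be expressed as a flex-attention BlockMask. It is adapted from the mask definitions discussed in the review below, not copied verbatim from the PR, and assumes a PyTorch build that ships torch.nn.attention.flex_attention:

from torch.nn.attention.flex_attention import create_block_mask

def ttt_step0_block_mask(seq_length: int):
    # Queries attend causally to the base-model KV block (first seq_length
    # positions) plus the single aligned position in the eagle KV block.
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx - 1) | (kv_idx == q_idx + seq_length)

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=2 * seq_length)

Because the mask is symbolic, flex attention can skip fully masked blocks instead of materializing a dense [seq_len, 2 * seq_len] tensor, which is where the sparsity speedup comes from.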

Usage

The API is unchanged, except that we now additionally pass training_seq_len to the model before training starts. See the change in `main.py`.
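Roughly, the call sites look like the sketch below (paraphrased from the diffs in this PR; the exact config key evolved during review, e.g. to eagle_training_seq_len at the top level):

# Fixed sequence length is plumbed into both the Eagle config and the data module.
config["eagle_architecture_config"].update(
    {
        # used to compile the flex-attention TTT masks
        "training_seq_len": training_args.training_seq_len,
    }
)
data_module = make_eagle_supervised_data_module(
    tokenizer, data_args, use_offline_training,
    max_length=training_args.training_seq_len,  # collators pad/truncate to this length
)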

Testing

Efficiency Test

On Llama3.2-1B we see a 2x+ speedup:

image

And a 1.5x memory saving:
image

Correctness Test

Test setting: daring-anteater dataset, TinyLlama model.
We tried both online and offline training and got an equivalent final AR before and after the optimization; after the optimization it is actually slightly higher. This could be due to the minor numeric error introduced by the padded matmul in the previous attention implementation.
image

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added option to set a fixed max sequence length for Eagle training/padding via training configuration.
    • Enabled flex attention for Eagle models to improve scalability and throughput.
  • Refactor

    • Streamlined training flow with shared padding lengths and improved cache/mask handling.
    • Simplified loss computation to focus on classification loss.
  • Tests

    • Removed a heavy parameterized test to speed up the test suite while keeping core coverage intact.


copy-pr-bot bot commented Sep 21, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 21, 2025

Walkthrough

Introduces optional max_length propagation through Eagle data collators and call sites, enables flex attention in EagleModule/HF model with block-mask-based TTT handling, refactors Eagle forward to loop over three TTT steps with cached masks and eagle_cache, simplifies loss to classification-only, updates main caller, and removes a heavy unit test.

Changes

  • Eagle data module and collators (examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/main.py): Adds optional max_length to make_eagle_supervised_data_module; threads max_length into DataCollatorWithPadding and offline collator; padding uses self.max_length; main.py passes training_args.training_seq_len as max_length for Eagle modes.
  • Flex attention + TTT refactor (modelopt/torch/speculative/plugins/transformers.py): Enables flex attention (config._attn_implementation="flex_attention"); introduces BlockMask-based TTT masks with caching; adds _compile_ttt_block_mask; updates _eagle_forward to accept eagle_cache and normalized caches; loops over 3 TTT steps; simplifies _eagle_loss to classification_loss and accuracy; aggregates train_accs.
  • Tests update (tests/unit/torch/speculative/plugins/test_hf_speculative.py): Removes import torch and deletes test_eagle_model_prepare_eagle_inputs(dtype); retains other tests.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant Main
  participant DataModule as Eagle Data Module
  participant Model as EagleModel
  participant FlexAttn as Flex Attention

  rect rgba(230,240,255,0.6)
    note right of Main: Configure training
    Trainer->>Main: Build training args (training_seq_len)
    Main->>DataModule: make_eagle_supervised_data_module(max_length=training_seq_len)
    DataModule-->>Main: Collators with fixed max_length
  end

  rect rgba(240,255,240,0.6)
    note right of Model: Forward with TTT loop
    Main->>Model: forward(inputs, eagle_cache=None)
    loop ttt_step in [0,1,2]
      Model->>Model: _get_ttt_attention_mask(seq_len, step)
      Model->>FlexAttn: apply BlockMask (cached/compiled)
      Model->>Model: _eagle_forward(..., eagle_cache)
      Model->>Model: compute classification_loss, accuracy
      Model-->>Model: update eagle_cache
    end
    Model-->>Main: logits, losses, train_accs
  end

  rect rgba(255,245,230,0.6)
    note over Model: Loss = classification only
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble code like clover leaves,
Three hops through TTT eaves—
Flexing masks in silent thrums,
Cache in paw, the forward hums.
Max-length rows in tidy lines,
Tests trimmed back, the logic shines.
Thump-thump—shipping time!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 43.48%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title succinctly and accurately summarizes the primary change: improving Eagle3 training efficiency by adding an Eagle KV cache and switching to flex attention, which aligns with the PR objectives and the code changes (KV caching, flex_attention enablement, and training_seq_len propagation).


@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch 2 times, most recently from 8938a09 to 6759b45 Compare September 21, 2025 02:54

copy-pr-bot bot commented Sep 21, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch from 6759b45 to f5835f9 Compare September 21, 2025 03:07
@h-guo18 h-guo18 changed the base branch from main to haoguo/test-offline-ci September 21, 2025 03:36
@h-guo18 h-guo18 changed the title Efficient Eagle3 training with cross attention and flex attention Efficient Eagle3 training with eagle KV cache and flex attention Sep 21, 2025
@h-guo18 h-guo18 self-assigned this Sep 21, 2025
@h-guo18 h-guo18 requested review from ChenhanYu, yeyu-nvidia and benchislett and removed request for benchislett September 22, 2025 00:07
@yeyu-nvidia
Contributor

What does "base model only" mean in the figure?

@yeyu-nvidia
Contributor

Can you also do a correctness check on online training?

@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch 3 times, most recently from e4a0b13 to 3415d15 Compare September 23, 2025 00:05
@h-guo18 h-guo18 changed the base branch from haoguo/test-offline-ci to main September 23, 2025 00:05
@h-guo18 h-guo18 marked this pull request as ready for review September 23, 2025 00:13
@h-guo18 h-guo18 requested a review from a team as a code owner September 23, 2025 00:13
@h-guo18 h-guo18 requested a review from yeyu-nvidia September 23, 2025 00:13
@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch from d364c25 to 3415d15 Compare September 23, 2025 00:17

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (6)
modelopt/torch/speculative/plugins/transformers.py (6)

39-39: Flex Attention import: ensure environment compatibility.

torch.nn.attention.flex_attention is version‑gated; add a clear error if unavailable.

Apply:

-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+except Exception as e:
+    raise ImportError(
+        "Flex Attention is required for EAGLE TTT. Please upgrade PyTorch to a version that "
+        "provides torch.nn.attention.flex_attention."
+    ) from e

187-189: Attn backend set in two places (flex here, sdpa elsewhere). Unify to avoid confusion.

HFEagleModel.modify later forces sdpa on self.eagle_config; pick one source of truth.

Apply either:

  • Remove/override the sdpa assignment in HFEagleModel.modify, or
  • Gate the choice behind a single config flag.

Example (modify HFEagleModel.modify outside this hunk):

-        self.eagle_config._attn_implementation = "sdpa"
+        self.eagle_config._attn_implementation = "flex_attention"  # keep consistent with EagleModule

454-460: Hardcoded num_ttt_steps=3 and eager compile rely on training_seq_len; make configurable and validate.

Generalize steps and fail early if training_seq_len is missing.

Apply:

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)
+        if not hasattr(self.eagle_config, "training_seq_len") or self.eagle_config.training_seq_len is None:
+            raise ValueError("eagle_config.training_seq_len must be set when using flex attention TTT.")
+        # compile and cache flex attention masks
+        self.cached_attn_blk_masks = [
+            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
+            for i in range(self.num_ttt_steps)
+        ]

538-572: Mask compiler: add minimal docstrings and boundary assertions.

Helps future maintenance and guards Q/KV bounds for each step.

Apply:

     def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
-        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+        """Compile per‑step TTT BlockMask for flex attention.
+        Assumes fixed training seq length and KV cache layout per step."""
+        assert seq_length > 0, "seq_length must be positive"

720-721: Lazily recompile masks if the actual sequence length drifts from configured training_seq_len.

Avoids hard failure when users change max_length without rebuilding the model.

Apply:

-            b, seq_length, h = base_model_hidden_states.shape
+            b, seq_length, h = base_model_hidden_states.shape
+            if getattr(self.eagle_config, "training_seq_len", None) != seq_length:
+                # Recompile masks on‑the‑fly
+                self.cached_attn_blk_masks = [
+                    self._compile_ttt_block_mask(seq_length, i) for i in range(self.num_ttt_steps)
+                ]
+                self.eagle_config.training_seq_len = int(seq_length)

785-803: Loss masking alignment looks correct. Normalize mask dtype to avoid type‑promotion surprises.

Apply:

-                classification_loss, acc = self._eagle_loss(
+                classification_loss, acc = self._eagle_loss(
                     # base model predict +1 tok, while eagle predict +2
                     # so we shift base model outputs compared to eagle outputs
                     base_model_logits[:, 1:],
                     eagle_logits[:, :-1],
                     # additionally, we mask the first n tok of eagle outputs at nth TTT step
-                    torch.cat(
+                    torch.cat(
                         (
                             torch.zeros(
-                                b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
+                                b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
                             ),
                             loss_mask[:, 2 + ttt_step :],
                         ),
                         dim=1,
                     ),
                 )

And inside _eagle_loss (outside this hunk) cast the mask once:

-        loss_mask = loss_mask[:, :, None]
+        loss_mask = loss_mask.to(eagle_logits.dtype)[:, :, None]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between add61db and 3415d15.

📒 Files selected for processing (3)
  • examples/speculative_decoding/eagle_utils.py (5 hunks)
  • examples/speculative_decoding/main.py (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _eagle_forward (1156-1179)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-315)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (7)
examples/speculative_decoding/eagle_utils.py (3)

239-243: API extension looks good and is backward‑compatible.

Optional max_length propagated cleanly.


319-321: Constructor OK. Consider documenting that sequences longer than max_length will be truncated (see below fix).


364-366: OK to forward max_length to offline collator.

Matches online path behavior.

modelopt/torch/speculative/plugins/transformers.py (3)

710-714: Good: normalize legacy caches.

Prevents attribute errors downstream.


771-784: OK: reuse inputs_embeds across TTT passes. Verify intended semantics.

If the first layer concatenates normalized embeds with shifted prenorm states, keeping inputs_embeds fixed is correct; otherwise, recompute per pass.


821-822: train_acc API change (list) — check downstream consumers.

Found only this assignment in repo: modelopt/torch/speculative/plugins/transformers.py:819-823 — train_acc = train_accs. Search for any callers/metrics/logging that assume a scalar and either aggregate the list here or update callers to accept a list.

examples/speculative_decoding/main.py (1)

232-234: Approve: max_length is passed and used for Eagle runs

Confirmed: main.py passes training_seq_len into make_eagle_supervised_data_module; examples/speculative_decoding/eagle_utils.py instantiates DataCollatorForOffline(max_length=...) and DataCollatorWithPadding(max_length=...) (DataCollatorForOffline forwards max_length). No other Eagle collator bypass found.

Locations: examples/speculative_decoding/main.py:232-234; examples/speculative_decoding/eagle_utils.py:301,309,363

Comment on lines 373 to 374
max_hs_length = (
self.max_length
if self.max_length is not None
else max(item["base_model_hidden_states"].shape[0] for item in features)
)

⚠️ Potential issue

Apply the same truncation guard for hidden states in offline mode.

Mirror the online clamp to avoid negative padding on preprocessed HS.

Apply:

-        max_hs_length = (
+        max_hs_length = (
             self.max_length
             if self.max_length is not None
             else max(item["base_model_hidden_states"].shape[0] for item in features)
         )
+        # Clamp oversize hidden states
+        for item in features:
+            if item["base_model_hidden_states"].shape[0] > max_hs_length:
+                item["base_model_hidden_states"] = item["base_model_hidden_states"][:max_hs_length]
+                item["aux_hidden_states"] = item["aux_hidden_states"][:max_hs_length]
📝 Committable suggestion


Suggested change
max_hs_length = (
self.max_length
if self.max_length is not None
else max(item["base_model_hidden_states"].shape[0] for item in features)
)
max_hs_length = (
self.max_length
if self.max_length is not None
else max(item["base_model_hidden_states"].shape[0] for item in features)
)
# Clamp oversize hidden states
for item in features:
if item["base_model_hidden_states"].shape[0] > max_hs_length:
item["base_model_hidden_states"] = item["base_model_hidden_states"][:max_hs_length]
item["aux_hidden_states"] = item["aux_hidden_states"][:max_hs_length]
🤖 Prompt for AI Agents
In examples/speculative_decoding/eagle_utils.py around lines 373 to 377, the
offline path sets max_hs_length to self.max_length (if not None) which can yield
negative padding when preprocessed hidden states are longer; compute the actual
maximum hidden-state length from features and ensure the final max_hs_length is
at least that value by replacing the assignment with a guarded clamp (e.g.,
compute computed_max = max(item["base_model_hidden_states"].shape[0] for item in
features) and then set max_hs_length = max(self.max_length or 0, computed_max)),
so offline mirrors the online clamp and avoids negative pad lengths.

Comment on lines 206 to 208
# pass in the seq length for flex attention mask compilation
"training_seq_len": training_args.training_seq_len,
}

⚠️ Potential issue

Good: plumbs fixed seq length into Eagle config for flex masks. Add a hard guard.

Protect users from silent shape mismatches by validating against base model limits here.

Apply:

                 config["eagle_architecture_config"].update(
                     {
                         "hidden_size": model.config.hidden_size,
                         "vocab_size": model.config.vocab_size,
                         # we also overwrite max_pos_embedding for deployment compatibility
                         "max_position_embeddings": model.config.max_position_embeddings,
                         "draft_vocab_size": custom_config["draft_vocab_size"]
                         if eagle_args.eagle_config and "draft_vocab_size" in custom_config
                         else model.config.vocab_size,
                         # pass in the seq length for flex attention mask compilation
-                        "training_seq_len": training_args.training_seq_len,
+                        "training_seq_len": int(training_args.training_seq_len),
                     }
                 )
+            # Sanity: training_seq_len must not exceed base model's max positions
+            if training_args.training_seq_len > model.config.max_position_embeddings:
+                raise ValueError(
+                    f"training_seq_len ({training_args.training_seq_len}) exceeds "
+                    f"base model max_position_embeddings ({model.config.max_position_embeddings})."
+                )
📝 Committable suggestion


Suggested change
# pass in the seq length for flex attention mask compilation
"training_seq_len": training_args.training_seq_len,
}
config["eagle_architecture_config"].update(
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
# we also overwrite max_pos_embedding for deployment compatibility
"max_position_embeddings": model.config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else model.config.vocab_size,
# pass in the seq length for flex attention mask compilation
"training_seq_len": int(training_args.training_seq_len),
}
)
# Sanity: training_seq_len must not exceed base model's max positions
if training_args.training_seq_len > model.config.max_position_embeddings:
raise ValueError(
f"training_seq_len ({training_args.training_seq_len}) exceeds "
f"base model max_position_embeddings ({model.config.max_position_embeddings})."
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 206 to 208, the code
passes training_args.training_seq_len into the Eagle config for flex attention
masks but lacks a hard guard against exceeding the base model's maximum context
length; add a validation right before building the config that reads the base
model's max sequence length (e.g., model.config.max_position_embeddings or
model.config.max_length depending on model type), verify
training_args.training_seq_len is not None and <= that max, and if it is greater
(or missing) raise a clear exception (ValueError) describing the mismatch and
advising the user to reduce training_seq_len or use a larger model so silent
shape mismatches are prevented.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/speculative/plugins/transformers.py (2)

731-737: Mixing 4‑D attn masks with flex attention can break; use BlockMask for the first pass too.

Step‑0 currently builds a 4‑D mask via _prepare_decoder_attention_mask while the subsequent TTT passes use BlockMask. Under flex_attention, unify to BlockMask for consistency and perf.

-            eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs(
+            eagle_input_ids, _attention_mask_unused, position_ids = self._get_eagle_module_inputs(
                 input_ids,
                 eagle_input_hidden_states,
                 attention_mask,
                 position_ids,
                 eagle_cache,
             )
@@
-            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+            # Use the precompiled BlockMask for step 0
+            blk_masks = _get_blk_masks(seq_length)
+            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                 eagle_input_hidden_states,
                 inputs_embeds,
-                attention_mask_0,
+                blk_masks[0],
                 position_ids,
                 position_embeddings,
                 eagle_cache,
             )

Also applies to: 743-750


503-507: Avoid hardcoding token id 0 when shifting input_ids. Use pad/eos.

Using 0 risks feeding BOS for models where pad_token_id ≠ 0. Prefer pad_token_id or eos as fallback.

-        zeropadding = torch.zeros(
-            input_ids.shape[0], 1, dtype=input_ids.dtype, device=input_ids.device
-        )
-        eagle_input_ids = torch.cat((input_ids[:, 1:], zeropadding), dim=1)
+        pad_id = getattr(self.model.config, "pad_token_id", None)
+        if pad_id is None:
+            pad_id = getattr(self.model.config, "eos_token_id", 0)
+        pad_col = torch.full((input_ids.shape[0], 1), pad_id, dtype=input_ids.dtype, device=input_ids.device)
+        eagle_input_ids = torch.cat((input_ids[:, 1:], pad_col), dim=1)
🧹 Nitpick comments (5)
examples/speculative_decoding/main.py (1)

232-234: Pass max_length to the Eagle data module — necessary for fixed-length training.

This ensures batches are padded/truncated to the mask-compile length. One ask: guard against accidental mismatch.

Consider asserting consistency at runtime to fail fast:

@@
-        data_module = make_eagle_supervised_data_module(
-            tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
-        )
+        data_module = make_eagle_supervised_data_module(
+            tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
+        )
+        assert training_args.dataloader_drop_last, \
+            "Set TrainingArguments.dataloader_drop_last=True to keep fixed batch shapes for flex attention."
examples/speculative_decoding/eagle_utils.py (2)

319-321: Small safety: clamp rather than assume max_length ≥ per-sample length.

Defensive guard avoids negative-size paddings if a caller passes a too-small max_length.

@@ class DataCollatorWithPadding:
-    def __init__(self, max_length=None):
-        self.max_length = max_length
+    def __init__(self, max_length=None):
+        self.max_length = max_length
@@
-        max_length = (
+        max_length = (
             self.max_length
             if self.max_length is not None
             else max(item["input_ids"].shape[0] for item in features)
         )
+        # Clamp to the longest in batch to avoid negative paddings if misconfigured
+        if self.max_length is not None:
+            max_length = max(max_length, max(item["input_ids"].shape[0] for item in features))

Also applies to: 334-338


364-366: Mirror clamp for hidden-state padding in offline mode.

Avoids negative paddings if hidden_state length > max_length due to stale .pt files.

@@ class DataCollatorForOffline(DataCollatorWithPadding):
-        max_hs_length = (
+        max_hs_length = (
             self.max_length
             if self.max_length is not None
             else max(item["base_model_hidden_states"].shape[0] for item in features)
         )
+        if self.max_length is not None:
+            max_hs_length = max(
+                max_hs_length, max(item["base_model_hidden_states"].shape[0] for item in features)
+            )

Also applies to: 373-377

modelopt/torch/speculative/plugins/transformers.py (2)

538-572: Document and validate mask semantics; add basic bounds checks.

The mask lambdas are dense; at minimum assert shapes and add comments clarifying which KV block each equality term targets. Also sanity check KV_LEN multiples.

Add inline comments like:

-            def msk(b, h, q_idx, kv_idx):
-                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
+            def msk(b, h, q_idx, kv_idx):
+                # Step 1: allow causal prefix (<= q-2), and cross links to positions
+                # (q + L - 1) and (q + 2L) in the concatenated KV stream ([L | L | L])

And before returning:

+        assert seq_length > 0, "seq_length must be positive for BlockMask"

712-714: Safe conversion of eagle_cache only when not None.

DynamicCache.from_legacy_cache(None) can allocate an empty cache and mask logic may assume “no cache”. Keep None semantics until the first _eagle_forward.

-        if not isinstance(eagle_cache, Cache):
-            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+        if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between add61db and 3415d15.

📒 Files selected for processing (3)
  • examples/speculative_decoding/eagle_utils.py (5 hunks)
  • examples/speculative_decoding/main.py (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-315)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _eagle_forward (1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (4)
examples/speculative_decoding/eagle_utils.py (2)

239-243: API: Adding max_length to make_eagle_supervised_data_module is the right call.

Keeps padding logic and mask compilation in sync.


301-301: Correct: Collators now honor a fixed max_length.

This avoids shape drift across batches and matches flex-attn BlockMask expectations.

Also applies to: 309-309

modelopt/torch/speculative/plugins/transformers.py (1)

821-822: train_acc returned as a list; confirm logging/serialization.

HF Trainer won’t auto-log lists. Either aggregate (mean) or expose step‑wise metrics under distinct keys.

Consider returning a scalar avg to avoid JSON serialization surprises:

-            train_acc=train_accs,
+            train_acc=sum(train_accs) / max(len(train_accs), 1),

If you want per‑step, emit in logs within training loop instead.

examples/speculative_decoding/main.py (1)

206-208: Approve — propagate training_seq_len into eagle_architecture_config (verify remaining tokenizer usage)

  • Mask compilation uses eagle_config.training_seq_len (modelopt/torch/speculative/plugins/transformers.py:457).
  • examples/speculative_decoding/eagle_utils.py still uses tokenizer.model_max_length for offline-data cropping (examples/speculative_decoding/eagle_utils.py:215); confirm if that should use eagle_config.training_seq_len instead.

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

⚠️ Potential issue

Version‑gating for flex_attention import missing.

torch.nn.attention.flex_attention is only available in recent PyTorch; importing unguarded will crash older envs. Add a soft fallback to SDPA.

-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask  # PyTorch >= supports flex attention
+    _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+    BlockMask = None
+    create_block_mask = None
+    _FLEX_ATTENTION_AVAILABLE = False

And down below (see next comments) gate usage accordingly.

📝 Committable suggestion


Suggested change
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
try:
from torch.nn.attention.flex_attention import BlockMask, create_block_mask # PyTorch >= supports flex attention
_FLEX_ATTENTION_AVAILABLE = True
except Exception:
BlockMask = None
create_block_mask = None
_FLEX_ATTENTION_AVAILABLE = False
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/transformers.py around line 39, the
unguarded import from torch.nn.attention.flex_attention will crash on older
PyTorch; wrap that import in a try/except ImportError (or AttributeError) and
set a module-level flag like HAS_FLEX_ATTENTION = True/False, and on failure
import or reference the standard SDPA utilities (e.g.,
torch.nn.functional.scaled_dot_product_attention) and provide minimal fallback
implementations or aliases for BlockMask/create_block_mask (or set them to None
and document behavior). Then update all downstream usage to check
HAS_FLEX_ATTENTION before calling flex_attention APIs and branch to the SDPA
fallback implementation when False.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3415d15 and 3531467.

📒 Files selected for processing (3)
  • examples/speculative_decoding/eagle_utils.py (5 hunks)
  • examples/speculative_decoding/main.py (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/eagle_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _eagle_forward (1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (12)
modelopt/torch/speculative/plugins/transformers.py (12)

319-319: Correct use of past_key_value parameter.

This matches HF LlamaDecoderLayer’s signature.


620-637: Eagle cache threading through _eagle_forward looks correct.


742-750: Confirm flex path accepts 4D attention_mask for the first pass.

With _attn_implementation="flex_attention", ensure HF LlamaAttention supports a 4D float mask here; otherwise, convert the first pass to a BlockMask too.

If needed, we can add a simple step-(-1) BlockMask equivalent of the standard causal mask for the first pass.
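For reference, such a causal-equivalent BlockMask could be sketched like this (illustrative only; it reuses the create_block_mask import already present in this file and is not part of the PR):

def _causal_block_mask(seq_length):
    # Plain causal mask as a BlockMask: query i may attend to kv positions <= i.
    def mask_mod(b, h, q_idx, kv_idx):
        return kv_idx <= q_idx

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length)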


755-762: Loss/acc aggregation change LGTM.


821-821: Returning train_acc list is fine; callers should expect a per-step vector.

Ensure downstream logging expects a list; otherwise, summarize (e.g., mean) before returning.


846-846: Loss return shape/signature change LGTM.


39-39: Version‑gate flex_attention import and provide a fallback flag.

Unguarded import will crash on older PyTorch lacking flex_attention. Add a try/except and a feature flag; downstream usage should branch accordingly.

-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask  # PyTorch >= 2.5
+    _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+    BlockMask = None
+    create_block_mask = None
+    _FLEX_ATTENTION_AVAILABLE = False

186-189: Don’t force flex_attention when unavailable; fall back to SDPA.

This hard-sets flex attention and will crash if the import failed. Gate on the feature flag.

-        # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # Prefer flex attention for efficient TTT when available
+        config._attn_implementation = "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa"

454-460: Make TTT steps configurable and mask compilation resilient to seq_len mismatches.

Hardcoding steps and precompiling a single seq_len can break when batches end up shorter (e.g., last batch) or configs change. Cache masks per observed seq_len and read steps from config with a default.

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)
+        # Precompile for configured training_seq_len; keep a runtime cache for other lengths.
+        self._blk_mask_cache: dict[int, list] = {}
+        def _compile_set(L: int):
+            self._blk_mask_cache[L] = [
+                self._compile_ttt_block_mask(L, i) for i in range(self.num_ttt_steps)
+            ]
+        if _FLEX_ATTENTION_AVAILABLE:
+            _compile_set(self.eagle_config.training_seq_len)

Add this helper in the class (nearby is fine):

def _get_blk_masks(self, L: int):
    if not _FLEX_ATTENTION_AVAILABLE:
        return None
    if L not in self._blk_mask_cache:
        # Compile on-demand for unexpected lengths (e.g., short last batch)
        self._blk_mask_cache[L] = [
            self._compile_ttt_block_mask(L, i) for i in range(self.num_ttt_steps)
        ]
    return self._blk_mask_cache[L]

538-571: Explain the TTT BlockMask semantics and guard when flex_attention is unavailable.

This is non-trivial logic; add a concise explanation and an availability check to fail fast with a clear error when flex isn’t present.

-    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
-        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
+        """Compile symbolic BlockMask for TTT step.
+
+        Semantics:
+          - Q_LEN = seq_length (current step tokens).
+          - KV_LEN = past_len + seq_length, where past_len accumulates per TTT step.
+          - We allow sparse connections that emulate the concat+shift behavior without materializing buffers:
+            * A limited causal window into the previous KV blocks.
+            * A small number of aligned "skip" positions into recent blocks to propagate predictions.
+        """
+        if not _FLEX_ATTENTION_AVAILABLE:
+            raise RuntimeError(
+                "flex_attention is required for TTT BlockMask; install a compatible PyTorch or set _attn_implementation='sdpa'."
+            )
         if ttt_step == 0:
 
             def msk(b, h, q_idx, kv_idx):
                 # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
 
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:
 
             def msk(b, h, q_idx, kv_idx):
                 # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                 return (
                     (kv_idx <= (q_idx - 2))
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )
 
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:
 
             def msk(b, h, q_idx, kv_idx):
                 # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                 return (
                     (kv_idx <= (q_idx - 3))
                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )
 
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")

Also, please add a short inline comment mapping kv_idx ranges to blocks (past/current/next) to address the prior reviewer ask.


764-803: Guard masks by actual seq_length at runtime.

Use masks matching the observed sequence length; fallback-compile on-the-fly when it differs from the configured training_seq_len.

-            for ttt_step in range(self.num_ttt_steps):
+            for ttt_step in range(self.num_ttt_steps):
                 eagle_input_hidden_states = torch.cat(
                     (
                         torch.zeros(
                             (b, 1, h),
                             dtype=eagle_input_hidden_states.dtype,
                             device=eagle_input_hidden_states.device,
                         ),
                         eagle_prenorm_h[:, :-1, :],
                     ),
                     dim=1,
                 )
-                attention_mask = self.cached_attn_blk_masks[ttt_step]
+                # Fetch a BlockMask matching the current seq length (compile on-demand if needed)
+                seq_length = eagle_input_hidden_states.size(1)
+                blk_masks = self._get_blk_masks(seq_length) if hasattr(self, "_get_blk_masks") else (
+                    self.cached_attn_blk_masks
+                    if seq_length == self.eagle_config.training_seq_len
+                    else [self._compile_ttt_block_mask(seq_length, i) for i in range(self.num_ttt_steps)]
+                )
+                attention_mask = blk_masks[ttt_step]
                 _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                     eagle_input_hidden_states,
                     inputs_embeds,
                     attention_mask,
                     position_ids,
                     position_embeddings,
                     eagle_cache,
                 )

454-460: Confirmed: training_seq_len is propagated and dataloader drop_last enforced for EAGLE training

  • training_seq_len (TrainingArguments) is passed to tokenizer.model_max_length and injected into eagle_config during mtsp.convert (examples/speculative_decoding/main.py).
  • DataCollatorWithPadding / DataCollatorForOffline pad/truncate to the provided max_length so batches are fixed-length (examples/speculative_decoding/eagle_utils.py).
  • Offline samples are truncated to tokenizer.model_max_length when loaded.
  • Trainer is created with training_args (dataloader_drop_last default True and launch scripts pass True), so incomplete batches are dropped — precompiled cached_attn_blk_masks (modelopt/torch/speculative/plugins/transformers.py) will match the padded batch length.

Comment on lines 710 to 718
if not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)


🧹 Nitpick

Avoid converting None eagle_cache via DynamicCache.from_legacy_cache(None).

Guard the conversion; some HF versions don’t treat None as legacy cache.

-        if not isinstance(eagle_cache, Cache):
-            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+        if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
📝 Committable suggestion


Suggested change
if not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
if not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if eagle_cache is not None and not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/transformers.py around lines 710 to 714,
avoid calling DynamicCache.from_legacy_cache(None) for eagle_cache; update the
guard so you only convert when eagle_cache is not None and not already a Cache
(e.g. if eagle_cache is not None and not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)), leaving None
untouched; apply the same None check for past_key_values if needed.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3531467 and 138c806.

📒 Files selected for processing (7)
  • examples/speculative_decoding/main.py (2 hunks)
  • modelopt/torch/speculative/config.py (1 hunks)
  • modelopt/torch/speculative/eagle/conversion.py (1 hunks)
  • modelopt/torch/speculative/eagle/eagle_model.py (2 hunks)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (13 hunks)
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/speculative/eagle/conversion.py
  • modelopt/torch/speculative/plugins/transformers.py
🧬 Code graph analysis (3)
modelopt/torch/speculative/config.py (1)
modelopt/torch/opt/config.py (1)
  • ModeloptField (50-53)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _eagle_forward (1158-1181)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-315)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (11)
modelopt/torch/speculative/eagle/conversion.py (1)

50-52: LGTM: threads training_seq_len into Eagle modify path.
No issues spotted; signature alignment with downstream modify implementations looks correct.

modelopt/torch/speculative/eagle/eagle_model.py (1)

38-39: LGTM: persists eagle_training_seq_len on the model.
Storage and signature change are consistent with conversion and plugins.

Also applies to: 49-49

modelopt/torch/speculative/plugins/megatron_eagle.py (1)

724-748: LGTM: correctly forwards eagle_training_seq_len through modify.
No functional concerns here.

examples/speculative_decoding/main.py (2)

234-236: Ensure collators enforce fixed length (pad/truncate) for flex/TTT.
You pass max_length; verify both online and offline collators pad to exactly training_seq_len.

If either collator only “caps” length without right‑padding to exact size, flex masks will misalign. Please confirm DataCollatorWithPadding and DataCollatorForOffline always output sequences of length == max_length.
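A quick way to check this is a one-off shape assertion after collating a sample batch (illustrative only; the keys "train_dataset", "data_collator", and "input_ids" are assumptions based on the usual HF example layout):

# Collate a couple of samples and confirm the fixed length is actually enforced.
collator = data_module["data_collator"]
batch = collator([data_module["train_dataset"][i] for i in range(2)])
assert batch["input_ids"].shape[1] == training_args.training_seq_len, (
    f"expected length {training_args.training_seq_len}, got {batch['input_ids'].shape[1]}"
)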


188-193: Hard‑guard training_seq_len against base model max positions.
Without a check, flex/TTT masks can mismatch and crash. Validate before updating config.

Apply:

-            # overwrite config with custom config
+            # overwrite config with custom config
+            # Sanity: training_seq_len must not exceed base model's max positions
+            max_pos = getattr(model.config, "max_position_embeddings", None)
+            if max_pos is not None and training_args.training_seq_len > max_pos:
+                raise ValueError(
+                    f"training_seq_len ({training_args.training_seq_len}) exceeds "
+                    f"base model max_position_embeddings ({max_pos})."
+                )
             config.update(
                 {
                     "eagle_offline": use_offline_training,
                     "eagle_training_seq_len": training_args.training_seq_len,
                 }
             )
modelopt/torch/speculative/plugins/transformers.py (6)

540-574: Generalize masks to arbitrary ttt_step or document limit.
Either parametrize to N steps or state the 3‑step assumption in config/docs.

If you want, I can refactor _compile_ttt_block_mask to generate the pattern programmatically for any N.
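For what it's worth, the three hand-written masks follow a regular pattern, so a programmatic version for any step could look roughly like this (derived from the step-0/1/2 definitions above; a sketch, not tested against the PR):

def _compile_ttt_block_mask(self, seq_length, ttt_step):
    # Causal prefix shifted by (ttt_step + 1), plus one aligned position per
    # previously cached eagle KV block; KV_LEN grows by seq_length per step.
    s = ttt_step

    def mask_mod(b, h, q_idx, kv_idx):
        allowed = kv_idx <= q_idx - (s + 1)
        for j in range(1, s + 2):
            allowed = allowed | (
                (kv_idx == q_idx + j * seq_length - (s + 1 - j)) & (kv_idx >= j * seq_length)
            )
        return allowed

    return create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * (s + 2)
    )

For ttt_step in {0, 1, 2} this reproduces the masks currently spelled out by hand.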


456-462: Precompiled masks need seq‑len awareness; add a runtime cache.
Current code precompiles for a single length; any drift causes mismatch. Cache by observed seq_len.

Also fix the comment typo (“cach” -> “cache”).
Apply:

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(eagle_training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)  # default 3
+        # Precompile for configured training_seq_len; keep a runtime cache for other lengths.
+        self._blk_mask_cache = {
+            eagle_training_seq_len: [
+                self._compile_ttt_block_mask(eagle_training_seq_len, i)
+                for i in range(self.num_ttt_steps)
+            ]
+        }

Add this helper (outside the shown hunk):

def _get_blk_masks(self, seq_len):
    if seq_len not in self._blk_mask_cache:
        self._blk_mask_cache[seq_len] = [
            self._compile_ttt_block_mask(seq_len, i) for i in range(self.num_ttt_steps)
        ]
    return self._blk_mask_cache[seq_len]

714-716: Avoid converting None with DynamicCache.from_legacy_cache.
Guard None to prevent HF cache utils from erroring.

Apply:

-        if not isinstance(eagle_cache, Cache):
-            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+        if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)

766-786: Fetch TTT masks for the actual seq_len during training‑time‑testing.
Use the runtime cache to avoid mismatches when seq_len varies.

Apply:

-            for ttt_step in range(self.num_ttt_steps):
+            for ttt_step in range(self.num_ttt_steps):
                 eagle_input_hidden_states = torch.cat(
                     (
                         torch.zeros(
                             (b, 1, h),
                             dtype=eagle_input_hidden_states.dtype,
                             device=eagle_input_hidden_states.device,
                         ),
                         eagle_prenorm_h[:, :-1, :],
                     ),
                     dim=1,
                 )
-                attention_mask = self.cached_attn_blk_masks[ttt_step]
+                attention_mask = self._get_blk_masks(seq_length)[ttt_step]
                 _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                     eagle_input_hidden_states,
                     inputs_embeds,
                     attention_mask,
                     position_ids,
                     position_embeddings,
                     eagle_cache,
                 )

Also applies to: 787-805


39-39: Guard flex_attention import; fall back gracefully when unavailable.
Unconditional import will crash on older PyTorch. Add a try/except and a feature flag.

Apply:

-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask  # PyTorch w/ flex attention
+    _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+    BlockMask = None
+    create_block_mask = None
+    _FLEX_ATTENTION_AVAILABLE = False

187-189: Don’t force flex_attention if it’s not available.
Prefer flex when present; otherwise use SDPA.

Apply:

-        # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # Prefer flex attention for efficient TTT; fall back to SDPA if unavailable
+        config._attn_implementation = "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa"

Comment on lines 98 to 101
eagle_training_seq_len: int = ModeloptField(
default=1024, description=("The training sequence length.")
)


🧹 Nitpick

Validate eagle_training_seq_len is positive (and small integers only).
Add a lower bound to avoid accidental 0/negative values.

Apply:

-    eagle_training_seq_len: int = ModeloptField(
-        default=1024, description=("The training sequence length.")
-    )
+    eagle_training_seq_len: int = ModeloptField(
+        default=1024,
+        description=("The training sequence length."),
+        ge=1,
+    )
📝 Committable suggestion


Suggested change
eagle_training_seq_len: int = ModeloptField(
default=1024, description=("The training sequence length.")
)
eagle_training_seq_len: int = ModeloptField(
default=1024,
description=("The training sequence length."),
ge=1,
)
🤖 Prompt for AI Agents
In modelopt/torch/speculative/config.py around lines 98 to 101, the
eagle_training_seq_len field lacks validation allowing 0/negative or
unreasonably large values; add constraints so it's an integer >= 1 and capped to
a reasonable upper bound (e.g. <= 65536 or 16384) to prevent accidental misuse.
Update the ModeloptField/Field call to include numeric constraints (e.g. gt/ ge
and le/lt or equivalent parameters) or add an explicit validator that checks
type int, verifies value >= 1 and <= chosen_max, and raises a clear ValueError
if violated.

@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch from 138c806 to c9bf27a Compare September 23, 2025 20:52
@h-guo18 h-guo18 force-pushed the hg/eagle3-crossattn branch from 47cddea to 94cbb2a Compare September 23, 2025 20:56
Signed-off-by: h-guo18 <[email protected]>

codecov bot commented Sep 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.46%. Comparing base (26c203a) to head (71c657e).
⚠️ Report is 12 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #350   +/-   ##
=======================================
  Coverage   73.46%   73.46%           
=======================================
  Files         172      172           
  Lines       17640    17640           
=======================================
+ Hits        12959    12960    +1     
+ Misses       4681     4680    -1     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

899-916: Mirror the flex BlockMask change in pseudo generation.

pseudo_speculative_generate also passes a dense 4‑D mask; switch to BlockMask when flex is available to avoid runtime errors.

Add after computing eagle_attention_mask:

if _FLEX_ATTENTION_AVAILABLE:
    seq_length = eagle_input_hidden_states.size(1)
    eagle_attention_mask = self._get_ttt_attention_mask(seq_length, -1)
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

454-454: Optional: make num_ttt_steps configurable (default 3).

Read from config to keep flexibility and avoid hardcoding.

Already covered in the cache fix above with getattr(self.eagle_config, "num_ttt_steps", 3).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 138c806 and 71c657e.

📒 Files selected for processing (4)
  • examples/speculative_decoding/eagle_utils.py (5 hunks)
  • examples/speculative_decoding/main.py (1 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _eagle_forward (1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/speculative/plugins/transformers.py (7)

314-323: HF compatibility: confirm past_key_value arg name across versions.

Recent HF uses past_key_value; older ones differ. Verify the target HF version and gate if needed.

If you need to support older HF too, consider a kwarg alias:

-            layer_outputs = decoder_layer(
+            layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                past_key_value=past_key_values,
+                past_key_value=past_key_values,  # alias for newer HF
+                # past_key_values=past_key_values,  # uncomment if targeting older HF that requires this name
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 position_embeddings=position_embeddings,
             )
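
If both HF generations must be supported, a small sketch (hypothetical helper, not the plugin's code) of detecting the kwarg name from the layer's signature instead of hardcoding it:

import inspect


def call_decoder_layer(decoder_layer, hidden_states, past_key_values, **kwargs):
    # Pass the KV cache under whichever name this HF version's layer declares.
    params = inspect.signature(decoder_layer.forward).parameters
    cache_kwarg = "past_key_value" if "past_key_value" in params else "past_key_values"
    return decoder_layer(hidden_states, **{cache_kwarg: past_key_values}, **kwargs)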

820-826: LGTM on surfacing train_acc.

Returning per‑step train accuracy is useful; leaving it empty during eval is fine.

Confirm consumers tolerate train_acc=[] at eval.


850-850: Loss/accuracy computation looks correct.

Soft target CE and masked aggregation are implemented as expected.


39-39: Version‑gate flex_attention import and provide a clean fallback.

Unguarded import will crash on older PyTorch. Gate the import and expose a module‑level flag.

-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+    _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+    BlockMask = None
+    create_block_mask = None
+    _FLEX_ATTENTION_AVAILABLE = False

Don’t force flex_attention; fall back to SDPA when unavailable.

Hard-coding it will fail outright in environments without flex attention support.

-        # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # Prefer flex attention for efficient TTT when available
+        config._attn_implementation = (
+            "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa"
+        )

454-464: Block mask cache keyed only by ttt_step → wrong when seq_length varies.

If seq_length changes (even accidentally), a mismatched mask will be reused. Cache by seq_length as well.

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        self._cached_attn_blk_masks = []
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)
+        # Cache: {seq_len: {ttt_step: BlockMask}}
+        self._blk_mask_cache = {}
 
-    def _get_ttt_attention_mask(self, seq_length, ttt_step):
-        # compile and cached flex attention masks in first call
-        if ttt_step >= len(self._cached_attn_blk_masks):
-            self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step))
-
-        # return cached flex attention mask
-        return self._cached_attn_blk_masks[ttt_step]
+    def _get_ttt_attention_mask(self, seq_length, ttt_step):
+        L = int(seq_length)
+        cache_for_len = self._blk_mask_cache.setdefault(L, {})
+        if ttt_step not in cache_for_len:
+            cache_for_len[ttt_step] = self._compile_ttt_block_mask(L, ttt_step)
+        return cache_for_len[ttt_step]

716-718: Guard against converting None eagle_cache to DynamicCache.

DynamicCache.from_legacy_cache(None) can error; add a None guard.

-        if not isinstance(eagle_cache, Cache):
-            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+        if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)

Comment on lines +542 to 576
    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
        if ttt_step == 0:

            def msk(b, h, q_idx, kv_idx):
                # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
        elif ttt_step == 1:

            def msk(b, h, q_idx, kv_idx):
                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                return (
                    (kv_idx <= (q_idx - 2))
                    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                )

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
        elif ttt_step == 2:

            def msk(b, h, q_idx, kv_idx):
                # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                return (
                    (kv_idx <= (q_idx - 3))
                    | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                    | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                    | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                )

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
        else:
            raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")


⚠️ Potential issue

Provide a causal BlockMask for the base pass and document semantics.

The first EAGLE pass uses a dense 4D mask, which won’t be consumed by flex attention. Add a base causal mask variant and reuse it when flex is enabled.

     def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
         """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+        # ttt_step == -1 is the base pass: standard causal self-attention (allow attending to <= current position).
+        if ttt_step == -1:
+            def msk(b, h, q_idx, kv_idx):
+                # [seq_len, seq_len]: standard causal mask (k <= q)
+                return kv_idx <= q_idx
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length)
         if ttt_step == 0:
             def msk(b, h, q_idx, kv_idx):
                 # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:
             def msk(b, h, q_idx, kv_idx):
                 # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                 return (
                     (kv_idx <= (q_idx - 2))
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:
             def msk(b, h, q_idx, kv_idx):
                 # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                 return (
                     (kv_idx <= (q_idx - 3))
                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )
             return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
        # ttt_step == -1 is the base pass: standard causal self-attention (allow attending to <= current position).
        if ttt_step == -1:

            def msk(b, h, q_idx, kv_idx):
                # [seq_len, seq_len]: standard causal mask (k <= q)
                return kv_idx <= q_idx

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length)
        if ttt_step == 0:

            def msk(b, h, q_idx, kv_idx):
                # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
        elif ttt_step == 1:

            def msk(b, h, q_idx, kv_idx):
                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                return (
                    (kv_idx <= (q_idx - 2))
                    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                )

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
        elif ttt_step == 2:

            def msk(b, h, q_idx, kv_idx):
                # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                return (
                    (kv_idx <= (q_idx - 3))
                    | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                    | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                    | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                )

            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
        else:
            raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
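
As a quick illustration (not part of the suggestion), densifying the step-0 predicate for a toy seq_length shows what it admits: the first seq_length KV slots (the cached previous pass) follow a shifted causal pattern, and the second seq_length slots (the current pass) are diagonal-only:

import torch

seq_length = 4
q = torch.arange(seq_length)[:, None]       # query positions
kv = torch.arange(2 * seq_length)[None, :]  # kv positions over [cached pass | current pass]
step0 = (kv <= (q - 1)) | (kv == q + seq_length)
print(step0.int())
# tensor([[0, 0, 0, 0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 1]])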

Comment on lines +747 to 754
            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                eagle_input_hidden_states,
                inputs_embeds,
                attention_mask_0,
                position_ids,
                position_embeddings,
                eagle_cache,
            )

⚠️ Potential issue

With flex enabled, the first EAGLE pass must use a BlockMask, not a 4‑D dense mask.

Currently attention_mask_0 is 4‑D; HF flex attention expects BlockMask. Switch to the causal BlockMask for the base pass when flex is available.

-            # Then, we run eagle forward
-            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+            # Then, we run eagle forward
+            if _FLEX_ATTENTION_AVAILABLE:
+                attention_mask_0 = self._get_ttt_attention_mask(seq_length, -1)
+            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                 eagle_input_hidden_states,
                 inputs_embeds,
                 attention_mask_0,
                 position_ids,
                 position_embeddings,
                 eagle_cache,
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            if _FLEX_ATTENTION_AVAILABLE:
                attention_mask_0 = self._get_ttt_attention_mask(seq_length, -1)
            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                eagle_input_hidden_states,
                inputs_embeds,
                attention_mask_0,
                position_ids,
                position_embeddings,
                eagle_cache,
            )

Comment on lines +768 to 788
            for ttt_step in range(self.num_ttt_steps):
                eagle_input_hidden_states = torch.cat(
                    (
                        torch.zeros(
                            (b, 1, h),
                            dtype=eagle_input_hidden_states.dtype,
                            device=eagle_input_hidden_states.device,
                        ),
                        eagle_prenorm_h[:, :-1, :],
                    ),
                    dim=1,
                )
                attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                    eagle_input_hidden_states,
                    inputs_embeds,
                    attention_mask,
                    position_ids,
                    position_embeddings,
                    eagle_cache,
                )
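
For intuition (not part of the diff), the cached KV grows by seq_length on every pass, which is why the block masks above use KV_LEN = (ttt_step + 2) * seq_length; a toy sketch with an assumed seq_length:

# Hypothetical seq_length, for illustration only.
seq_length = 1024
kv_len = seq_length  # after the base eagle pass, the cache already holds seq_length keys/values
for ttt_step in range(3):
    kv_len += seq_length  # each TTT pass appends its own keys/values to eagle_cache
    print(f"ttt_step={ttt_step}: queries={seq_length}, kv={kv_len} (= {ttt_step + 2} * seq_length)")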

⚠️ Potential issue

🧩 Analysis chain

Recompute RoPE positions for TTT passes when cache grows.

Position embeddings should advance by past_len when you carry over KV from previous passes; reuse of the initial position_embeddings can skew RoPE distances.

-                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+                past_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0
+                pos_ids_ttt = position_ids + past_len
+                pos_emb_ttt = self.eagle_rotary_emb(eagle_input_hidden_states, pos_ids_ttt)
+                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                     eagle_input_hidden_states,
                     inputs_embeds,
-                    attention_mask,
-                    position_ids,
-                    position_embeddings,
+                    attention_mask,
+                    position_ids,
+                    pos_emb_ttt,
                     eagle_cache,
                 )

Please sanity‑check numerics on a short batch to confirm equivalence before/after this change.


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Searching repo for relevant symbols..."
rg -n --hidden --no-ignore -S "def _eagle_forward" || true
rg -n --hidden --no-ignore -S "eagle_rotary_emb" || true
rg -n --hidden --no-ignore -S "get_seq_length" || true
rg -n --hidden --no-ignore -S "position_embeddings" || true
rg -n --hidden --no-ignore -S "position_ids" || true

echo
echo "Showing snippet around the reported location (if file exists):"
file="modelopt/torch/speculative/plugins/transformers.py"
if [ -f "$file" ]; then
  echo "File: $file"
  sed -n '1,240p' "$file" | sed -n '720,820p'
else
  echo "File not found: $file"
fi

echo
echo "Showing definitions/usages for _eagle_forward and eagle_rotary_emb (if found):"
rg -n --hidden --no-ignore -S "def _eagle_forward" || true
rg -n --hidden --no-ignore -S "def eagle_rotary_emb" || true
rg -n --hidden --no-ignore -S "class .*Cache" || true
rg -n --hidden --no-ignore -S "get_seq_length" || true

Length of output: 20763


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="modelopt/torch/speculative/plugins/transformers.py"
echo "====== Showing file header (1-520) ======"
sed -n '1,520p' "$file" || true

echo
echo "====== Showing lines 600-820 (contains _eagle_forward and TTT loop) ======"
sed -n '600,820p' "$file" || true

echo
echo "====== Showing lines 820-980 (follow-up) ======"
sed -n '820,980p' "$file" || true

echo
echo "====== position_embeddings occurrences in file ======"
rg -n "position_embeddings" "$file" || true

echo
echo "====== eagle_rotary_emb assignment and usages ======"
rg -n "eagle_rotary_emb" "$file" || true

echo
echo "====== Search for def get_seq_length across repo ======"
rg -n --hidden --no-ignore -S "def get_seq_length" || true

echo
echo "====== Search for classes/types that might implement get_seq_length ======"
rg -n --hidden --no-ignore -S "class .*KeyValue|class .*Cache|class .*PastKey" || true

echo
echo "====== Show nearby context where eagle_cache.get_seq_length is used ======"
rg -n --hidden --no-ignore -n "eagle_cache.get_seq_length|past_key_values.get_seq_length" || true

Length of output: 38601


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -n "test_speculative_megatron_modules" || true
sed -n '1,360p' tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py || true

Length of output: 13544


Recompute RoPE positions for TTT passes when cache grows.

Position embeddings are computed once and reused in the TTT loop; when eagle_cache grows the rotary positions must be advanced by past_len or RoPE distances become incorrect.

Location: modelopt/torch/speculative/plugins/transformers.py — the TTT loop (for ttt_step in range(self.num_ttt_steps):).

-                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+                past_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0
+                pos_ids_ttt = position_ids + past_len
+                pos_emb_ttt = self.eagle_rotary_emb(eagle_input_hidden_states, pos_ids_ttt)
+                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                     eagle_input_hidden_states,
                     inputs_embeds,
-                    attention_mask,
-                    position_ids,
-                    position_embeddings,
+                    attention_mask,
+                    position_ids,
+                    pos_emb_ttt,
                     eagle_cache,
                 )

Please sanity-check numerics on a short batch to confirm equivalence before/after this change.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            for ttt_step in range(self.num_ttt_steps):
                eagle_input_hidden_states = torch.cat(
                    (
                        torch.zeros(
                            (b, 1, h),
                            dtype=eagle_input_hidden_states.dtype,
                            device=eagle_input_hidden_states.device,
                        ),
                        eagle_prenorm_h[:, :-1, :],
                    ),
                    dim=1,
                )
                attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
                past_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0
                pos_ids_ttt = position_ids + past_len
                pos_emb_ttt = self.eagle_rotary_emb(eagle_input_hidden_states, pos_ids_ttt)
                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                    eagle_input_hidden_states,
                    inputs_embeds,
                    attention_mask,
                    position_ids,
                    pos_emb_ttt,
                    eagle_cache,
                )
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/transformers.py around lines 768–788, the
TTT loop reuses the initial position embeddings on every iteration even though
eagle_cache grows. On each pass, compute the current past length from
eagle_cache, offset the base position_ids by that past_len (or regenerate them),
recreate position_embeddings from the offset positions, and pass the updated
values into _eagle_forward. After implementing this, run a small-batch numeric
sanity check to confirm equivalence.
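
A self-contained sketch (not the repo's RoPE code) of why the offset matters: rotary attention scores depend only on the relative offset between query and key positions, so cached keys and new queries must be advanced consistently when the cache grows:

import torch


def rope(x, pos, theta=10000.0):
    # Rotate consecutive feature pairs by angle pos * freq (a minimal RoPE).
    d = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    ang = pos * freqs
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)


torch.manual_seed(0)
q, k = torch.randn(8), torch.randn(8)
score = rope(q, 7) @ rope(k, 3)                       # query at 7, key at 3 -> offset 4
score_shifted = rope(q, 7 + 100) @ rope(k, 3 + 100)   # same offset, both advanced by past_len=100
print(torch.allclose(score, score_shifted, atol=1e-4))  # True: scores depend only on the offset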

@h-guo18
Contributor Author

h-guo18 commented Sep 23, 2025

/ok to test 71c657e

@h-guo18 h-guo18 merged commit 615f3c0 into main Sep 26, 2025
27 checks passed
@h-guo18 h-guo18 deleted the hg/eagle3-crossattn branch September 26, 2025 21:04
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025