
Conversation

yeyu-nvidia (Contributor) commented Sep 19, 2025

What does this PR do?

Type of change: refactor

Overview:
Our current set_multi_step_attn_mask function hardcodes the attention mask for step=2,3,4 and does not support an arbitrary step. This PR generalizes it to support any step > 1.

Usage

# Add a code snippet demonstrating how to use this
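A hypothetical usage sketch (not from the PR itself): it assumes set_multi_step_attention_mask is importable from modelopt.torch.speculative.plugins.megatron_eagle, takes (attn_mask, step) as shown in the walkthrough below, and follows Megatron's boolean convention where True marks positions that are not attended.

import torch

# Assumed import path; the function lives in the file this PR touches.
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

b, s = 2, 8  # batch size and base sequence length (toy values)
# Causal base mask: True = masked (not attended).
base_mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
base_mask = base_mask.view(1, 1, s, s).expand(b, 1, s, s).contiguous()

step = 5  # any step > 1 is supported after this PR
multi_step_mask = set_multi_step_attention_mask(base_mask, step)
print(multi_step_mask.shape)  # expected: (b, 1, step * s, step * s) -> (2, 1, 40, 40)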

Testing

Tested the attention mask locally and passed the sandbox regression test.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Expanded support for multi-step speculative decoding beyond four steps.
  • Performance

    • More streamlined multi-step attention mask construction may reduce overhead for larger step counts.
  • Reliability

    • Unified, iterative mask building improves consistency and reduces edge-case failures.
  • Compatibility

    • Public interfaces unchanged; existing integrations continue to work without modification.

@yeyu-nvidia requested a review from a team as a code owner September 19, 2025 17:57

coderabbitai bot commented Sep 19, 2025

Walkthrough

Refactors multi-step attention mask construction from fixed branches to a loop supporting arbitrary steps, removes step limit assertion, and adjusts mask concatenation logic. Also modifies _get_eagle_module_inputs to propagate from attn_mask itself rather than the original attention_mask.

Changes

  • Multi-step attention mask refactor (modelopt/torch/speculative/plugins/megatron_eagle.py): replaced the per-step branches (2–4) with a generalized loop over range(2, step + 1), constructing zero_mask, mask_0, and mask_1 each iteration and concatenating accordingly; removed the assertion limiting step <= 4; consolidated the final concatenations; unified control flow without early returns (a runnable sketch of this loop follows below).
  • Eagle module inputs mask propagation (modelopt/torch/speculative/plugins/megatron_eagle.py): in _get_eagle_module_inputs, the slice assignment now derives from the evolving attn_mask (attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]) instead of from attention_mask. The public signature of set_multi_step_attention_mask is unchanged.
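To make the refactor concrete, below is a self-contained sketch of the iterative construction summarized in the first bullet. It re-types the loop from the PR's diff but pins allocations to the input tensor via new_ones (the device-safe form suggested in the review comments further down); the helper name and toy shapes are invented for this sketch.

import torch

def multi_step_mask_sketch(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Grow a (b, 1, s, s) boolean mask into a (b, 1, step*s, step*s) block mask."""
    s = attn_mask.shape[-1]
    for step_idx in range(2, step + 1):
        # Top-right block: earlier rows never attend to the newly appended columns (True = masked).
        zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
        # Bottom-left block: start from the last s rows, fully mask row step_idx - 2,
        # then shift the pattern one column to the left.
        mask_0 = attn_mask[:, :, -s:, :].clone()
        mask_0[:, :, step_idx - 2] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        # Bottom-right block: all masked except a partial diagonal
        # (self-attention re-enabled for positions i >= step_idx - 1).
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False
        attn_mask = torch.cat(
            (
                torch.cat((attn_mask, zero_mask), dim=-1),
                torch.cat((mask_0, mask_1), dim=-1),
            ),
            dim=-2,
        )
    return attn_mask

base = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1).view(1, 1, 4, 4)
print(multi_step_mask_sketch(base, step=3).shape)  # torch.Size([1, 1, 12, 12])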

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Eagle as MegatronEagle
  rect rgba(220,235,245,0.5)
    note over Eagle: set_multi_step_attention_mask(step >= 2)
    Caller->>Eagle: set_multi_step_attention_mask(attn_mask, step)
    loop for iter in range(2, step+1)
      note over Eagle: Build zero_mask, mask_0, mask_1
      Eagle->>Eagle: Concatenate attn_mask ⊕ zero_mask ⊕ (mask_0, mask_1)
    end
    Eagle-->>Caller: attn_mask
  end

  rect rgba(235,245,220,0.5)
    note over Eagle: _get_eagle_module_inputs
    Caller->>Eagle: _get_eagle_module_inputs(inputs)
    Eagle->>Eagle: attn_mask[:,:,:-1,:-1] = attn_mask[:,:,1:,1:]
    Eagle-->>Caller: module_inputs
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

I stitched new masks in looping flight,
No step cap now—just onward night.
I nudge the grid to copy through,
From self to self, a fresher view.
A hop, a refactor—clean and spry,
Carrots aligned, my tensors fly! 🥕🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title "Refactor set_multi_step_attn_mask for arbitrary step" is a concise, single-sentence summary that accurately reflects the PR's primary change: generalizing the multi-step attention mask implementation to handle arbitrary step values rather than hardcoded cases. It names the affected function (using the common "attn" abbreviation) and the intent (refactor for arbitrary step), so a reviewer scanning history can quickly grasp the main change.
  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9e398c8 and 57f9143.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs



coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

317-318: Optional: vectorize the diagonal write to avoid a Python loop.

Minor perf tidy-up; reduces Python overhead for large s.

-        for i in range(step_idx - 1, s - 1):
-            mask_1[:, :, i, i] = False
+        if step_idx - 1 < s - 1:
+            idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+            mask_1[:, :, idx, idx] = False
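For reference, a quick standalone check (toy shapes, assumed for illustration) that the vectorized diagonal write matches the loop it replaces:

import torch

b, h, s, step_idx = 1, 1, 6, 3
device = torch.device("cpu")  # stands in for attn_mask.device

# Loop form (current code)
mask_loop = torch.ones(b, h, s, s, dtype=torch.bool)
for i in range(step_idx - 1, s - 1):
    mask_loop[:, :, i, i] = False

# Vectorized form (suggested above)
mask_vec = torch.ones(b, h, s, s, dtype=torch.bool)
if step_idx - 1 < s - 1:
    idx = torch.arange(step_idx - 1, s - 1, device=device)
    mask_vec[:, :, idx, idx] = False

assert torch.equal(mask_loop, mask_vec)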
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a3fd29 and 8d81046.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

309-327: Generalization approach looks good.

The iterative block-growth builds an (steps)×(steps) mask and removes the hardcoded branches. Nice cleanup.

Please re-run the sandbox regression with:

  • CUDA tensors (to confirm device fix), and
  • very short sequences (e.g., s ∈ {1,2,3}) with step ∈ {2,3,4} to ensure no out-of-bounds and correct shapes.

Comment on lines 309 to 327
     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False

-    if step == 3:
         attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
-        return attn_mask
-
-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )
     return attn_mask

⚠️ Potential issue

Fix device mismatch and short‑sequence OOB in generalized mask builder.

  • zero_mask and mask_1 are created on CPU by default (no device specified). Concatenating them with a CUDA attn_mask will crash at runtime.
  • For very short sequences (s = attn_mask.shape[-1]) and larger step, mask_0[:, :, iter-2] will index out of bounds when iter-2 >= s.

Apply this diff to pin allocations to the input mask’s device/dtype and to cap step defensively for short sequences. Also avoid cloning the whole tensor before slicing and rename iter → step_idx for clarity.

@@
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
-        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-        mask_0[:, :, iter - 2] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+    s = attn_mask.shape[-1]
+    # Bound step to avoid indexing past base sequence length on short sequences.
+    max_step = min(int(step), s + 1)
+    if max_step != step:
+        warnings.warn(
+            f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
+            f"for base sequence length s={s}."
+        )
+    for step_idx in range(2, max_step + 1):
+        # step_idx starts from 2nd step
+        zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+        mask_0 = attn_mask[:, :, -s:, :].clone()
+        row_idx = step_idx - 2
+        if row_idx < s:
+            mask_0[:, :, row_idx] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+        for i in range(step_idx - 1, s - 1):
+            mask_1[:, :, i, i] = False
@@
-        attn_mask = torch.cat(
+        attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_0, mask_1), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    s = attn_mask.shape[-1]
    # Bound step to avoid indexing past base sequence length on short sequences.
    max_step = min(int(step), s + 1)
    if max_step != step:
        warnings.warn(
            f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
            f"for base sequence length s={s}."
        )
    for step_idx in range(2, max_step + 1):
        # step_idx starts from 2nd step
        zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
        mask_0 = attn_mask[:, :, -s:, :].clone()
        row_idx = step_idx - 2
        if row_idx < s:
            mask_0[:, :, row_idx] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False
        attn_mask = torch.cat(
            (
                torch.cat((attn_mask, zero_mask), dim=-1),
                torch.cat((mask_0, mask_1), dim=-1),
            ),
            dim=-2,
        )
    return attn_mask
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 309 to 327,
the generalized mask builder creates zero_mask and mask_1 on CPU and may index
out of bounds for very short sequences; change zero_mask and mask_1 allocations
to use the input attn_mask's device and dtype (e.g., device=attn_mask.device,
dtype=attn_mask.dtype), defensively cap the loop upper bound to min(step, s) (or
compute max_steps = min(step, s)) to avoid mask_0[:, :, step_idx-2] OOB, avoid
cloning the whole attn_mask before slicing (use attn_mask[..., -s:, :] or
appropriate slicing) and rename the loop variable iter to step_idx for clarity;
ensure mask_0 and mask_1 are created with matching shapes and types before
concatenation so all tensors share device/dtype.
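A small standalone illustration of the device point above (the shapes are made up; CUDA is used only if available):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
attn_mask = torch.zeros(2, 1, 4, 4, dtype=torch.bool, device=device)

# torch.ones(...) without a device argument allocates on the default device
# (CPU unless overridden), which triggers the mismatch when attn_mask lives on a GPU.
cpu_block = torch.ones(2, 1, 4, 4).bool()
# Tensor.new_ones inherits both device and dtype from attn_mask.
matched_block = attn_mask.new_ones(2, 1, 4, 4)

print(cpu_block.device, cpu_block.dtype)          # cpu torch.bool
print(matched_block.device, matched_block.dtype)  # attn_mask's device, torch.bool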


codecov bot commented Sep 19, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.84%. Comparing base (4c36abe) to head (57f9143).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #348   +/-   ##
=======================================
  Coverage   73.83%   73.84%           
=======================================
  Files         172      172           
  Lines       17453    17453           
=======================================
+ Hits        12887    12888    +1     
+ Misses       4566     4565    -1     

☔ View full report in Codecov by Sentry.

Signed-off-by: Ye Yu <[email protected]>

coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

309-327: Fix device/dtype mismatch and short‑sequence OOB; don’t shadow built‑in iter.

  • zero_mask/mask_1 are allocated on CPU; concatenating with CUDA attn_mask will crash.
  • For short sequences, mask_0[:, :, iter-2] can index OOB when step > s+1.
  • Unnecessary full clone before slicing; and iter shadows the Python builtin.

Apply this diff:

 s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
-    # iter starts from 2nd step
-    zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
-        mask_1[:, :, i, i] = False
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
+    # step_idx starts from 2nd step
+    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    row_idx = step_idx - 2
+    if row_idx < s:
+        mask_0[:, :, row_idx] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+    for i in range(step_idx - 1, s - 1):
+        mask_1[:, :, i, i] = False

Optional follow-up: vectorize the diagonal update later for perf.

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

845-849: Avoid in‑place aliasing; use the original mask as RHS.

Assigning from attn_mask into an overlapping slice is unnecessary and can be error‑prone. Read from attention_mask (the source you just cloned) for clarity and safety.

- attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+ attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
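A tiny sketch of this point (toy shapes; it assumes attn_mask starts out as a clone of attention_mask, as the surrounding code implies):

import torch

s = 4
attention_mask = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)
attn_mask = attention_mask.clone()

# Reading from the untouched source tensor means the slice being written and the
# slice being read never alias each other, so the result is unambiguous.
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
print(attn_mask[0, 0].int())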
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8d81046 and 9e398c8.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
