Refactor set_multi_step_attn_mask for arbitrary step #348
base: main
Conversation
Signed-off-by: Ye Yu <[email protected]>
Walkthrough
Refactors multi-step attention mask construction from fixed branches to a loop supporting arbitrary steps, removes the step limit assertion, and adjusts the mask concatenation logic. Also modifies _get_eagle_module_inputs to propagate from attn_mask itself rather than the original attention_mask.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller
participant Eagle as MegatronEagle
rect rgba(220,235,245,0.5)
note over Eagle: set_multi_step_attention_mask(step >= 2)
Caller->>Eagle: set_multi_step_attention_mask(attn_mask, step)
loop for iter in range(2, step+1)
note over Eagle: Build zero_mask, mask_0, mask_1
Eagle->>Eagle: Concatenate attn_mask ⊕ zero_mask ⊕ (mask_0, mask_1)
end
Eagle-->>Caller: attn_mask
end
rect rgba(235,245,220,0.5)
note over Eagle: _get_eagle_module_inputs
Caller->>Eagle: _get_eagle_module_inputs(inputs)
Eagle->>Eagle: attn_mask[:,:,:-1,:-1] = attn_mask[:,:,1:,1:]
Eagle-->>Caller: module_inputs
end
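To make the block growth concrete, here is a minimal, self-contained sketch (not taken from the PR; toy shapes, and the block contents are placeholders) showing how each loop iteration widens the mask by one s×s block per side:

import torch

# Toy dimensions; names mirror the PR code but the block contents here are placeholders.
b, h, s = 2, 1, 5
attn_mask = torch.triu(torch.ones(b, h, s, s), diagonal=1).bool()  # base causal mask (True = masked)
for step_idx in range(2, 4):  # grow for steps 2 and 3
    cur = attn_mask.shape[-2]                      # current height/width: (step_idx - 1) * s
    zero_mask = attn_mask.new_ones(b, h, cur, s)   # new top-right column block, fully masked
    mask_0 = attn_mask[:, :, -s:, :].clone()       # new bottom-left row block (placeholder contents)
    mask_1 = attn_mask.new_ones(b, h, s, s)        # new bottom-right diagonal block (placeholder contents)
    attn_mask = torch.cat(
        (
            torch.cat((attn_mask, zero_mask), dim=-1),
            torch.cat((mask_0, mask_1), dim=-1),
        ),
        dim=-2,
    )
    assert attn_mask.shape[-2:] == (step_idx * s, step_idx * s)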
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
317-318: Optional: vectorize the diagonal write to avoid a Python loop.
Minor perf tidy-up; reduces Python overhead for large s.
-for i in range(step_idx - 1, s - 1):
-    mask_1[:, :, i, i] = False
+if step_idx - 1 < s - 1:
+    idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+    mask_1[:, :, idx, idx] = False
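A quick way to sanity-check this suggestion (illustrative snippet, not part of the review): the advanced-index write should produce the same mask_1 as the Python loop for any step_idx and s.

import torch

b, h, s, step_idx = 2, 1, 7, 3  # arbitrary toy values
ref = torch.ones(b, h, s, s, dtype=torch.bool)
for i in range(step_idx - 1, s - 1):
    ref[:, :, i, i] = False

vec = torch.ones(b, h, s, s, dtype=torch.bool)
if step_idx - 1 < s - 1:
    idx = torch.arange(step_idx - 1, s - 1)
    vec[:, :, idx, idx] = False

assert torch.equal(ref, vec)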
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Generalization approach looks good.
The iterative block growth builds a (step·s)×(step·s) mask and removes the hardcoded branches. Nice cleanup.
Please re-run the sandbox regression with:
- CUDA tensors (to confirm the device fix), and
- very short sequences (e.g., s ∈ {1,2,3}) with step ∈ {2,3,4} to ensure no out-of-bounds and correct shapes; a quick sketch of such a check follows below.
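A sketch of such a check might look like this (hypothetical harness; it assumes set_multi_step_attention_mask is importable as a module-level helper, which may differ from the actual API):

import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
for s in (1, 2, 3):                                  # very short base sequences
    base = torch.triu(torch.ones(1, 1, s, s, device=device), diagonal=1).bool()
    for step in (2, 3, 4):
        out = set_multi_step_attention_mask(base.clone(), step)
        assert out.shape[-1] == out.shape[-2]        # result stays square
        assert out.device == base.device             # confirms the device fix
        print(s, step, tuple(out.shape))             # inspect the resulting sizes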
     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
-
-    if step == 3:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False
+
+        attn_mask = torch.cat(
+            (
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
+            ),
+            dim=-2,
+        )
     return attn_mask
Fix device mismatch and short‑sequence OOB in generalized mask builder.
- zero_mask and mask_1 are created on CPU by default (no device specified). Concatenating them with a CUDA attn_mask will crash at runtime.
- For very short sequences (s = attn_mask.shape[-1]) and larger step, mask_0[:, :, iter-2] will index out of bounds when iter-2 >= s.
Apply this diff to pin allocations to the input mask’s device/dtype and to cap step defensively for short sequences. Also avoid cloning the whole tensor before slicing and rename iter → step_idx for clarity.
@@
- s = attn_mask.shape[-1]
- for iter in range(2, step + 1):
- # iter starts from 2nd step
- zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
- mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
- mask_0[:, :, iter - 2] = True
- mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
- mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
- for i in range(iter - 1, s - 1):
- mask_1[:, :, i, i] = False
+ s = attn_mask.shape[-1]
+ # Bound step to avoid indexing past base sequence length on short sequences.
+ max_step = min(int(step), s + 1)
+ if max_step != step:
+ warnings.warn(
+ f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
+ f"for base sequence length s={s}."
+ )
+ for step_idx in range(2, max_step + 1):
+ # step_idx starts from 2nd step
+ zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+ mask_0 = attn_mask[:, :, -s:, :].clone()
+ row_idx = step_idx - 2
+ if row_idx < s:
+ mask_0[:, :, row_idx] = True
+ mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+ mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+ for i in range(step_idx - 1, s - 1):
+ mask_1[:, :, i, i] = False
@@
- attn_mask = torch.cat(
+ attn_mask = torch.cat(
(
- torch.cat((attn_mask, zero_mask), dim=-1),
- torch.cat((mask_0, mask_1), dim=-1),
+ torch.cat((attn_mask, zero_mask), dim=-1),
+ torch.cat((mask_0, mask_1), dim=-1),
),
dim=-2,
)
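For reference (illustrative, not part of the diff): Tensor.new_ones allocates with the source tensor's dtype on the source tensor's device, which is what keeps zero_mask and mask_1 on the same device as attn_mask.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
attn_mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool, device=device)
zero_mask = attn_mask.new_ones(1, 1, 4, 4)  # bool ones on the same device as attn_mask
assert zero_mask.dtype == attn_mask.dtype and zero_mask.device == attn_mask.device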
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
s = attn_mask.shape[-1]
# Bound step to avoid indexing past base sequence length on short sequences.
max_step = min(int(step), s + 1)
if max_step != step:
    warnings.warn(
        f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
        f"for base sequence length s={s}."
    )
for step_idx in range(2, max_step + 1):
    # step_idx starts from 2nd step
    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
    mask_0 = attn_mask[:, :, -s:, :].clone()
    row_idx = step_idx - 2
    if row_idx < s:
        mask_0[:, :, row_idx] = True
    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
    for i in range(step_idx - 1, s - 1):
        mask_1[:, :, i, i] = False

    attn_mask = torch.cat(
        (
            torch.cat((attn_mask, zero_mask), dim=-1),
            torch.cat((mask_0, mask_1), dim=-1),
        ),
        dim=-2,
    )
return attn_mask
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 309 to 327,
the generalized mask builder creates zero_mask and mask_1 on CPU and may index
out of bounds for very short sequences; change zero_mask and mask_1 allocations
to use the input attn_mask's device and dtype (e.g., device=attn_mask.device,
dtype=attn_mask.dtype), defensively cap the loop upper bound to min(step, s) (or
compute max_steps = min(step, s)) to avoid mask_0[:, :, step_idx-2] OOB, avoid
cloning the whole attn_mask before slicing (use attn_mask[..., -s:, :] or
appropriate slicing) and rename the loop variable iter to step_idx for clarity;
ensure mask_0 and mask_1 are created with matching shapes and types before
concatenation so all tensors share device/dtype.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #348 +/- ##
=======================================
Coverage 73.83% 73.84%
=======================================
Files 172 172
Lines 17453 17453
=======================================
+ Hits 12887 12888 +1
+ Misses 4566 4565 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: Ye Yu <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Fix device/dtype mismatch and short‑sequence OOB; don't shadow the built‑in iter.
- zero_mask/mask_1 are allocated on CPU; concatenating with a CUDA attn_mask will crash.
- For short sequences, mask_0[:, :, iter - 2] can index OOB when step > s + 1.
- Unnecessary full clone before slicing; and iter shadows the Python builtin.
Apply this diff:
 s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
-    # iter starts from 2nd step
-    zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
-        mask_1[:, :, i, i] = False
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
+    # step_idx starts from 2nd step
+    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    row_idx = step_idx - 2
+    if row_idx < s:
+        mask_0[:, :, row_idx] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+    for i in range(step_idx - 1, s - 1):
+        mask_1[:, :, i, i] = False
Optional follow-up: vectorize the diagonal update later for perf.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
845-849: Avoid in‑place aliasing; use the original mask as RHS.
Assigning from attn_mask into an overlapping slice of itself is unnecessary and can be error‑prone. Read from attention_mask (the source you just cloned) for clarity and safety.
- attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+ attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
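A small sketch of the suggested pattern (toy data; names follow the review comment): reading the shifted window from the untouched clone source avoids any read/write overlap on the tensor being modified.

import torch

attention_mask = torch.triu(torch.ones(1, 1, 5, 5), diagonal=1).bool()  # stand-in for the original mask
attn_mask = attention_mask.clone()
# RHS reads from attention_mask, so the slice being written never overlaps the slice being read.
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]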
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/megatron_eagle.py
(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: refactor
Overview:
Our current set_multi_step_attn_mask function hardcodes the attention mask for step=2,3,4 and does not support arbitrary steps. This PR generalizes it to support any step > 1.
Usage
# Add a code snippet demonstrating how to use this
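One possible snippet (illustrative only; it assumes the function is importable as below and takes a boolean causal mask plus the step count):

import torch
from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask  # assumed import path

s = 8
attn_mask = torch.triu(torch.ones(1, 1, s, s), diagonal=1).bool()   # base causal mask (True = masked)
multi_step_mask = set_multi_step_attention_mask(attn_mask, step=5)  # any step > 1 is now supported
print(multi_step_mask.shape)                                        # one s-sized block per step along each side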
Testing
Tested the attention mask locally and passed the sandbox regression test.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Performance
Reliability
Compatibility