64 changes: 12 additions & 52 deletions modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -305,65 +305,25 @@ def set_multi_step_attention_mask(attn_mask, step):
=======================================================================================================================
""" # noqa: E501
     assert step > 1, "step should be larger than 1 in multi-step attention mask."
-    assert step <= 4, "Currently only a step of 4 or smaller is supported!"

     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False

-    if step == 3:
         attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
-        return attn_mask

-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )
     return attn_mask
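
For reference, a small sanity check of the generalized builder on a toy input. This is not part of the PR; it assumes the plugin module imports in your environment (it pulls in Megatron-LM) and that `True` marks masked-out positions, as in the rest of this file.

import torch

from modelopt.torch.speculative.plugins.megatron_eagle import set_multi_step_attention_mask

s = 4  # toy sequence length
# Boolean causal mask of shape (batch, 1, s, s); True = masked out.
base_mask = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)

# Each extra step appends an s-wide block column and an s-tall block row,
# so step=3 grows the mask from (s, s) to (3*s, 3*s).
multi_step_mask = set_multi_step_attention_mask(base_mask.clone(), step=3)
print(multi_step_mask.shape)  # expected: torch.Size([1, 1, 12, 12])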
Comment on lines 309 to 329

⚠️ Potential issue

Fix device mismatch and short‑sequence OOB in generalized mask builder.

  • zero_mask and mask_1 are created on CPU by default (no device specified). Concatenating them with a CUDA attn_mask will crash at runtime.
  • For very short sequences (s = attn_mask.shape[-1]) and larger step, mask_0[:, :, iter-2] will index out of bounds when iter-2 >= s.

Apply this diff to pin allocations to the input mask’s device/dtype and to cap step defensively for short sequences. Also avoid cloning the whole tensor before slicing and rename iter → step_idx for clarity.

@@
-    s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
-        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-        mask_0[:, :, iter - 2] = True
-        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
-            mask_1[:, :, i, i] = False
+    s = attn_mask.shape[-1]
+    # Bound step to avoid indexing past base sequence length on short sequences.
+    max_step = min(int(step), s + 1)
+    if max_step != step:
+        warnings.warn(
+            f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
+            f"for base sequence length s={s}."
+        )
+    for step_idx in range(2, max_step + 1):
+        # step_idx starts from 2nd step
+        zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+        mask_0 = attn_mask[:, :, -s:, :].clone()
+        row_idx = step_idx - 2
+        if row_idx < s:
+            mask_0[:, :, row_idx] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+        for i in range(step_idx - 1, s - 1):
+            mask_1[:, :, i, i] = False
@@
-        attn_mask = torch.cat(
+        attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_0, mask_1), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
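
To make the two failure modes concrete, here is a tiny illustration (not from the PR; the CUDA branch assumes a GPU is available):

import torch

base = torch.triu(torch.ones(1, 1, 4, 4, dtype=torch.bool), diagonal=1)

# 1) Device mismatch: torch.ones(...) allocates on the CPU by default, so
#    concatenating it with a CUDA mask raises a RuntimeError about tensors
#    living on different devices; new_ones() inherits device and dtype instead.
if torch.cuda.is_available():
    cuda_mask = base.cuda()
    cpu_block = torch.ones(1, 1, 4, 4).bool()      # CPU, as in the current loop
    # torch.cat((cuda_mask, cpu_block), dim=-1)    # would raise RuntimeError
    gpu_block = cuda_mask.new_ones(1, 1, 4, 4)     # bool, on the same CUDA device
    _ = torch.cat((cuda_mask, gpu_block), dim=-1)  # OK

# 2) Short-sequence OOB: when step - 2 >= s (e.g. s == 1 with step >= 3),
#    mask_0[:, :, step_idx - 2] indexes past the s-sized row dimension, which
#    is what the min(step, s + 1) cap above guards against.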
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    s = attn_mask.shape[-1]
    # Bound step to avoid indexing past base sequence length on short sequences.
    max_step = min(int(step), s + 1)
    if max_step != step:
        warnings.warn(
            f"set_multi_step_attention_mask: capping step from {step} to {max_step} "
            f"for base sequence length s={s}."
        )
    for step_idx in range(2, max_step + 1):
        # step_idx starts from 2nd step
        zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
        mask_0 = attn_mask[:, :, -s:, :].clone()
        row_idx = step_idx - 2
        if row_idx < s:
            mask_0[:, :, row_idx] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False

        attn_mask = torch.cat(
            (
                torch.cat((attn_mask, zero_mask), dim=-1),
                torch.cat((mask_0, mask_1), dim=-1),
            ),
            dim=-2,
        )

    return attn_mask
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 309 to 327,
the generalized mask builder creates zero_mask and mask_1 on CPU and may index
out of bounds for very short sequences; change zero_mask and mask_1 allocations
to use the input attn_mask's device and dtype (e.g., device=attn_mask.device,
dtype=attn_mask.dtype), defensively cap the loop upper bound to min(step, s) (or
compute max_steps = min(step, s)) to avoid mask_0[:, :, step_idx-2] OOB, avoid
cloning the whole attn_mask before slicing (use attn_mask[..., -s:, :] or
appropriate slicing) and rename the loop variable iter to step_idx for clarity;
ensure mask_0 and mask_1 are created with matching shapes and types before
concatenation so all tensors share device/dtype.



@@ -883,7 +843,7 @@ def _get_eagle_module_inputs(
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])

         attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
+        attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
         attn_mask[:, :, -1, :] = True
         attn_mask[:, :, :, -1] = True
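
For intuition, a standalone sketch of the shift this hunk performs on a toy causal mask (illustrative only; the inputs are stand-ins and `True` is assumed to mark masked-out positions):

import torch

s = 4
attention_mask = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)

# Copy the sub-mask that starts one row and one column later into the top-left
# block, then fully mask the trailing row and column, mirroring the hunk above.
attn_mask = attention_mask.clone().detach()
attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
attn_mask[:, :, -1, :] = True
attn_mask[:, :, :, -1] = True
print(attn_mask.squeeze().int())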
