Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 14 additions & 52 deletions modelopt/torch/speculative/plugins/megatron_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,65 +305,27 @@ def set_multi_step_attention_mask(attn_mask, step):
=======================================================================================================================
""" # noqa: E501
assert step > 1, "step should be larger than 1 in multi-step attention mask."
assert step <= 4, "Currently only a step of 4 or smaller is supported!"

s = attn_mask.shape[-1]
zero_mask = torch.ones_like(attn_mask).bool()
mask_2_1 = attn_mask.clone().detach()
mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
mask_2_2 = torch.ones_like(attn_mask).bool()
for i in range(1, s - 1):
mask_2_2[:, :, i, i] = False

if step == 2:
attn_mask = torch.cat(
(
torch.cat((attn_mask, zero_mask), dim=-1),
torch.cat((mask_2_1, mask_2_2), dim=-1),
),
dim=-2,
)
return attn_mask

mask_3_1 = mask_2_1.clone().detach()
mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
mask_3_2 = mask_2_2.clone().detach()
mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
mask_3_2[:, :, 1, 0] = True
mask_3_3 = mask_2_2.clone().detach()
mask_3_3[:, :, 1, 1] = True
for iter in range(2, step + 1):
# iter starts from 2nd step
zero_mask = attn_mask.new_ones(
attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
).bool()
mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
mask_0[:, :, iter - 2, :] = True
mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
for i in range(iter - 1, s - 1):
mask_1[:, :, i, i] = False

if step == 3:
attn_mask = torch.cat(
(
torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
torch.cat((attn_mask, zero_mask), dim=-1),
torch.cat((mask_0, mask_1), dim=-1),
),
dim=-2,
)
return attn_mask

mask_4_1 = mask_3_1.clone().detach()
mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
mask_4_2 = mask_3_2.clone().detach()
mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
mask_4_2[:, :, 2, 0] = True
mask_4_3 = mask_3_3.clone().detach()
mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
mask_4_3[:, :, 2, 1] = True
mask_4_4 = mask_3_3.clone().detach()
mask_4_4[:, :, 2, 2] = True

attn_mask = torch.cat(
(
torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
),
dim=-2,
)
return attn_mask


Expand Down Expand Up @@ -883,7 +845,7 @@ def _get_eagle_module_inputs(
rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])

attn_mask = attention_mask.clone().detach()
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
attn_mask[:, :, -1, :] = True
attn_mask[:, :, :, -1] = True

Expand Down
Loading