diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index 2a0e63a3..a0978fd4 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -305,65 +305,27 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     assert step > 1, "step should be larger than 1 in multi-step attention mask."
-    assert step <= 4, "Currently only a step of 4 or smaller is supported!"
 
     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = attn_mask.new_ones(
+            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
+        ).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2, :] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False
 
-    if step == 3:
         attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
-        return attn_mask
-
-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )
 
     return attn_mask
 
@@ -883,7 +845,7 @@ def _get_eagle_module_inputs(
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
 
         attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
+        attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
         attn_mask[:, :, -1, :] = True
         attn_mask[:, :, :, -1] = True
 
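For reviewers who want to sanity-check the refactor locally, here is a minimal standalone sketch (not part of the patch). It copies the new loop and the old hard-coded `step == 2` branch into free functions and confirms they build the same multi-step mask when starting from a causal base mask. The `make_causal_mask` helper and the "`True` means masked" convention are assumptions made only for this check; in the real code path the base mask comes from the Megatron attention pipeline.

```python
# Sketch only: compares the generalized loop against the removed step == 2 branch.
import torch


def make_causal_mask(s: int) -> torch.Tensor:
    # Hypothetical stand-in for the real attention mask; True = masked out.
    return torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)


def new_multi_step_mask(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    # Mirrors the loop added in the patch (`it` plays the role of `iter`).
    s = attn_mask.shape[-1]
    for it in range(2, step + 1):
        zero_mask = attn_mask.new_ones(
            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
        ).bool()
        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
        mask_0[:, :, it - 2, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(it - 1, s - 1):
            mask_1[:, :, i, i] = False
        # Append one block-row and one block-column per extra step.
        attn_mask = torch.cat(
            (
                torch.cat((attn_mask, zero_mask), dim=-1),
                torch.cat((mask_0, mask_1), dim=-1),
            ),
            dim=-2,
        )
    return attn_mask


def old_step2_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    # Mirrors the removed hard-coded step == 2 branch.
    s = attn_mask.shape[-1]
    zero_mask = torch.ones_like(attn_mask).bool()
    mask_2_1 = attn_mask.clone().detach()
    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
    mask_2_2 = torch.ones_like(attn_mask).bool()
    for i in range(1, s - 1):
        mask_2_2[:, :, i, i] = False
    return torch.cat(
        (
            torch.cat((attn_mask, zero_mask), dim=-1),
            torch.cat((mask_2_1, mask_2_2), dim=-1),
        ),
        dim=-2,
    )


if __name__ == "__main__":
    base = make_causal_mask(8)
    assert torch.equal(new_multi_step_mask(base.clone(), 2), old_step2_mask(base.clone()))
    # The loop also extends past the old step <= 4 cap, e.g. step=6 yields a 6s x 6s mask.
    print(new_multi_step_mask(base.clone(), 6).shape)  # torch.Size([1, 1, 48, 48])
```

The design point of the refactor is visible in the sketch: instead of materializing a separate set of hard-coded blocks for steps 2, 3, and 4, each iteration appends one `s x s` block-row/column to the running mask, so the function no longer needs the `step <= 4` assertion.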