Skip to content

Commit 8d81046

Browse files
committed
Refactor multi-step attention mask to support an arbitrary number of steps
Signed-off-by: Ye Yu <[email protected]>
1 parent 5a3fd29 commit 8d81046

File tree

1 file changed

+11
-51
lines changed

1 file changed

+11
-51
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -305,65 +305,25 @@ def set_multi_step_attention_mask(attn_mask, step):
305305
=======================================================================================================================
306306
""" # noqa: E501
307307
assert step > 1, "step should be larger than 1 in multi-step attention mask."
308-
assert step <= 4, "Currently only a step of 4 or smaller is supported!"
309308

310309
s = attn_mask.shape[-1]
311-
zero_mask = torch.ones_like(attn_mask).bool()
312-
mask_2_1 = attn_mask.clone().detach()
313-
mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
314-
mask_2_2 = torch.ones_like(attn_mask).bool()
315-
for i in range(1, s - 1):
316-
mask_2_2[:, :, i, i] = False
317-
318-
if step == 2:
319-
attn_mask = torch.cat(
320-
(
321-
torch.cat((attn_mask, zero_mask), dim=-1),
322-
torch.cat((mask_2_1, mask_2_2), dim=-1),
323-
),
324-
dim=-2,
325-
)
326-
return attn_mask
327-
328-
mask_3_1 = mask_2_1.clone().detach()
329-
mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
330-
mask_3_2 = mask_2_2.clone().detach()
331-
mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
332-
mask_3_2[:, :, 1, 0] = True
333-
mask_3_3 = mask_2_2.clone().detach()
334-
mask_3_3[:, :, 1, 1] = True
310+
for iter in range(2, step + 1):
311+
# iter starts from 2nd step
312+
zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
313+
mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
314+
mask_0[:, :, iter - 2] = True
315+
mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
316+
mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
317+
for i in range(iter - 1, s - 1):
318+
mask_1[:, :, i, i] = False
335319

336-
if step == 3:
337320
attn_mask = torch.cat(
338321
(
339-
torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
340-
torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
341-
torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
322+
torch.cat((attn_mask, zero_mask), dim=-1),
323+
torch.cat((mask_0, mask_1), dim=-1),
342324
),
343325
dim=-2,
344326
)
345-
return attn_mask
346-
347-
mask_4_1 = mask_3_1.clone().detach()
348-
mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
349-
mask_4_2 = mask_3_2.clone().detach()
350-
mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
351-
mask_4_2[:, :, 2, 0] = True
352-
mask_4_3 = mask_3_3.clone().detach()
353-
mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
354-
mask_4_3[:, :, 2, 1] = True
355-
mask_4_4 = mask_3_3.clone().detach()
356-
mask_4_4[:, :, 2, 2] = True
357-
358-
attn_mask = torch.cat(
359-
(
360-
torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
361-
torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
362-
torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
363-
torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
364-
),
365-
dim=-2,
366-
)
367327
return attn_mask
368328

369329

0 commit comments

Comments
 (0)