1 parent 9e398c8 commit 57f9143
modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -309,11 +309,13 @@ def set_multi_step_attention_mask(attn_mask, step):
     s = attn_mask.shape[-1]
     for iter in range(2, step + 1):
         # iter starts from 2nd step
-        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        zero_mask = attn_mask.new_ones(
+            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
+        ).bool()
         mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, iter - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
         for i in range(iter - 1, s - 1):
             mask_1[:, :, i, i] = False
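The substantive change is swapping torch.ones for Tensor.new_ones when allocating the helper masks: per standard PyTorch semantics, new_ones creates the tensor on attn_mask's device and with its dtype, whereas torch.ones defaults to a float32 tensor on the CPU. Presumably this avoids a device mismatch (or an extra transfer) when attn_mask lives on a GPU; the added [:, :, iter - 2, :] index is just the explicit form of the previous slice. Below is a minimal sketch of that device behavior, with illustrative shapes rather than the masks actually built by set_multi_step_attention_mask:

import torch

# Illustrative shapes only; not the mask shapes used by set_multi_step_attention_mask.
device = "cuda" if torch.cuda.is_available() else "cpu"
attn_mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool, device=device)

# torch.ones(...) ignores attn_mask entirely: it returns a float32 tensor on the
# default (CPU) device, which then needs an explicit cast/move before use.
detached_mask = torch.ones(attn_mask.shape).bool()
print(detached_mask.device, detached_mask.dtype)  # cpu torch.bool

# Tensor.new_ones(...) inherits device and dtype from attn_mask, so the result can
# be combined with attn_mask directly (the trailing .bool() is then a no-op here).
matched_mask = attn_mask.new_ones(attn_mask.shape).bool()
print(matched_mask.device, matched_mask.dtype)  # attn_mask's device, torch.bool

assert matched_mask.device == attn_mask.device

On a GPU run this skips the implicit CPU allocation and any later device transfer, which appears to be the motivation for the change.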