Skip to content

Commit 57f9143

Browse files
committed
make new mask the same dtype and device as attn_mask
Signed-off-by: Ye Yu <[email protected]>
1 parent 9e398c8 commit 57f9143

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 5 additions & 3 deletions
Diff (reconstructed unified-diff hunk):

@@ -309,11 +309,13 @@ def set_multi_step_attention_mask(attn_mask, step):
         s = attn_mask.shape[-1]
         for iter in range(2, step + 1):
             # iter starts from 2nd step
-            zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+            zero_mask = attn_mask.new_ones(
+                attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
+            ).bool()
             mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-            mask_0[:, :, iter - 2] = True
+            mask_0[:, :, iter - 2, :] = True
             mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-            mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+            mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
             for i in range(iter - 1, s - 1):
                 mask_1[:, :, i, i] = False
319321

0 commit comments

Comments (0)