
Commit ab4513a

Refactors attention mask to use boolean dtype
Improves type consistency and performance by using torch.bool instead of a generic floating-point dtype for attention masks. Eliminates unnecessary type conversions and simplifies the mask comparison logic by comparing against False instead of 0.0.
1 parent 9751d0f commit ab4513a
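
For quick reference, below is a minimal, self-contained sketch of the refactored prepare_dynamic_mask helper with the boolean mask in place. The function signature, the keep_window_size guard, and the min_dtype derivation are assumptions reconstructed from the hunk context in the diffs below; only the changed mask lines come from this commit.

import torch


def prepare_dynamic_mask(attn_bias: torch.Tensor, keep_window_size: int = 2048):
    """Keep the top-k bias entries per row and build a torch.bool attention mask."""
    min_dtype = torch.finfo(attn_bias.dtype).min  # assumed sentinel for masked-out positions
    if attn_bias.shape[-1] > keep_window_size:  # guard condition is an assumption
        topk_values, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        valid_topk = topk_values != min_dtype
        # Boolean mask: no .to(dtype) cast on valid_topk and no 0.0 comparison needed.
        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)  # noqa: E712
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
    return attn_bias, attn_mask

A quick illustrative check of the returned mask dtype and sparsity (shapes are made up for the example):

bias = torch.randn(1, 8, 128, 4096)
bias, mask = prepare_dynamic_mask(bias, keep_window_size=2048)
print(mask.dtype)                    # torch.bool
print(int(mask.sum(dim=-1).max()))   # at most 2048 kept positions per row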

File tree

benchmarks/backward_equivalence.py
benchmarks/backward_performance.py
benchmarks/forward_equivalence.py
benchmarks/forward_performance.py

4 files changed: +16 −16 lines

benchmarks/backward_equivalence.py
Lines changed: 4 additions & 4 deletions

@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

benchmarks/backward_performance.py
Lines changed: 4 additions & 4 deletions

@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

benchmarks/forward_equivalence.py
Lines changed: 4 additions & 4 deletions

@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

benchmarks/forward_performance.py
Lines changed: 4 additions & 4 deletions

@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
