Skip to content

Commit b98f2a9

Browse files
LoserCheems and Copilot authored
Apply suggestions from code review
Co-authored-by: Copilot <[email protected]>
1 parent 2aba697 commit b98f2a9

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

benchmarks/backward_equivalence.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def prepare_dynamic_mask(
8989
valid_topk = topk_values != min_dtype
9090
attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
9191
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
92-
attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
92+
attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
9393
else:
9494
attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
9595
return attn_bias, attn_mask

benchmarks/backward_performance.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ def prepare_dynamic_mask(
111111
valid_topk = topk_values != min_dtype
112112
attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
113113
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
114-
attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
114+
attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
115115
else:
116116
attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
117117
return attn_bias, attn_mask

benchmarks/forward_equivalence.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def prepare_dynamic_mask(
8989
valid_topk = topk_values != min_dtype
9090
attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
9191
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
92-
attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
92+
attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
9393
else:
9494
attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
9595
return attn_bias, attn_mask

benchmarks/forward_performance.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ def prepare_dynamic_mask(
111111
valid_topk = topk_values != min_dtype
112112
attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
113113
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
114-
attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
114+
attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
115115
else:
116116
attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
117117
return attn_bias, attn_mask

0 commit comments

Comments (0)