
Commit ade4ef3

Broadcasts attention bias over query dimension
Updates the forward/backward equivalence benchmarks to create the attention bias with a singleton query dimension so that it broadcasts across queries. This aligns shapes with kernel expectations during cached decoding, reduces the memory footprint, and prevents shape mismatches across the CUDA, Triton, and Flex paths.
1 parent 2951e24 commit ade4ef3
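
A minimal sketch of the change in isolation: a bias with a singleton query dimension still adds to the full (batch, heads, query_len, key_len) score tensor via broadcasting, while storing query_len times fewer elements. The shapes below are assumed for illustration only and are not taken from the benchmarks.

import torch

batch_size, num_kv_heads, query_len, key_len = 2, 4, 8, 128

# Attention scores for every (query, key) pair.
scores = torch.randn(batch_size, num_kv_heads, query_len, key_len)

# Old benchmark shape: one bias row per query position.
bias_full = torch.randn(batch_size, num_kv_heads, query_len, key_len)

# New benchmark shape: a single row that broadcasts across the query
# dimension, matching the kernel's expectation during cached decoding.
bias_broadcast = torch.randn(batch_size, num_kv_heads, 1, key_len)

# Both add cleanly to the scores; the broadcast bias is query_len times smaller.
out_full = scores + bias_full
out_broadcast = scores + bias_broadcast
assert out_broadcast.shape == (batch_size, num_kv_heads, query_len, key_len)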

2 files changed, 5 insertions(+), 5 deletions(-)

benchmarks/backward_equivalence.py

Lines changed: 2 additions & 2 deletions
@@ -597,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -831,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)

benchmarks/forward_equivalence.py

Lines changed: 3 additions & 3 deletions
@@ -570,7 +570,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=torch.bfloat16
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -758,7 +758,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=torch.bfloat16
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -963,7 +963,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=torch.bfloat16
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
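
A hedged usage sketch, not part of the diff: PyTorch's reference scaled_dot_product_attention also broadcasts a float attn_mask with a singleton query dimension, so a bias shaped like the one above composes with a standard attention call. All shapes here are assumptions for illustration, not the benchmarks' actual configuration.

import torch
import torch.nn.functional as F

batch_size, num_heads, query_len, key_len, head_dim = 1, 2, 4, 16, 8

q = torch.randn(batch_size, num_heads, query_len, head_dim)
k = torch.randn(batch_size, num_heads, key_len, head_dim)
v = torch.randn(batch_size, num_heads, key_len, head_dim)

# Float bias with a singleton query dimension; it is broadcast to
# (batch, heads, query_len, key_len) when added to the attention scores.
attn_bias = torch.randn(batch_size, num_heads, 1, key_len)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
assert out.shape == (batch_size, num_heads, query_len, head_dim)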
