
Commit 1cd83de

pytorchbot and Isalia20 authored
[Flex attention] Fix flex attention head broadcast (pytorch#164368)
[Flex attention] Fix flex attention head broadcast (pytorch#163426)

Fixes part of pytorch#163314, in particular:

**Bug 1: H=None broadcasting produces incorrect results**

This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Even though the mask_mod remains interpretable, the kernel's stride math then reads wrong offsets, producing silent numerical mismatches against regular SDPA, especially with single-position decoding and H broadcasting.

The fact that the B=None, H=None case works is accidental: with a singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer and it happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides, causing silent errors.

Pull Request resolved: pytorch#163426
Approved by: https://github.com/drisspg
(cherry picked from commit 1a42656)

Co-authored-by: Isalia20 <[email protected]>
1 parent 881c2cc commit 1cd83de
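For illustration only (not part of the commit): a minimal sketch of the shape collapse described above, using plain tensors with made-up [B, H, Q_tiles, KV_tiles] sizes rather than a real BlockMask. Indexing the Q-tile axis with a Python int drops that axis, while an equivalent length-1 slice keeps it, which is the behavior the fix below switches to.

```python
import torch

# Hypothetical BlockMask metadata layout [B, H, Q_tiles, KV_tiles]; sizes are made up.
B, H, Q_TILES, KV_TILES = 4, 2, 16, 16
kv_indices = torch.zeros(B, H, Q_TILES, KV_TILES, dtype=torch.int32)

i = 3  # Q-tile index, as in mask[:, :, i]

# Int indexing collapses the Q-tile axis, so stride math written against a
# [B, H, Q_tiles, ...] layout reads the wrong offsets.
collapsed = kv_indices[:, :, i]
print(collapsed.shape)  # torch.Size([4, 2, 16]) -- Q-tile axis is gone

# Rewriting the int as a length-1 slice keeps every axis, so the strides stay valid.
kept = kv_indices[:, :, i : i + 1]
print(kept.shape)  # torch.Size([4, 2, 1, 16])
```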

File tree

2 files changed: +89 −7


test/inductor/test_flex_attention.py

Lines changed: 80 additions & 7 deletions
@@ -4666,8 +4666,8 @@ def causal_mask(b, h, q, kv):
 
         block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
-        self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
-        self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
+        self.assertEqual(block_mask[0].shape, (1, 2, 2048, 2048))
+        self.assertEqual(block_mask[0, 0].shape, (1, 1, 2048, 2048))
         self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
         self.assertEqual(block_mask.sparsity(), 46.875)
         self.assertEqual(block_mask[0].sparsity(), 46.875)
@@ -4711,13 +4711,26 @@ def causal_mask(b, h, q, kv):
 
         # Index on batch dimension
         new_block_mask = block_mask[0]
-        assert new_block_mask.kv_num_blocks.shape == (2, 4)
-        assert new_block_mask.kv_indices.shape == (2, 4, 4)
+        assert new_block_mask.kv_num_blocks.shape == (1, 2, 4)
+        assert new_block_mask.kv_indices.shape == (1, 2, 4, 4)
 
         # Index on batch and head dimension
         new_block_mask = block_mask[0, 1]
-        assert new_block_mask.kv_num_blocks.shape == (4,)
-        assert new_block_mask.kv_indices.shape == (4, 4)
+        assert new_block_mask.kv_num_blocks.shape == (
+            1,
+            1,
+            4,
+        )
+        assert new_block_mask.kv_indices.shape == (1, 1, 4, 4)
+
+        # Index on batch and head dimension with -1 semantics
+        new_block_mask = block_mask[-1, -2]
+        assert new_block_mask.kv_num_blocks.shape == (
+            1,
+            1,
+            4,
+        )
+        assert new_block_mask.kv_indices.shape == (1, 1, 4, 4)
 
         # slicing on batch and head dimension
         new_block_mask = block_mask[0:2, 1:2]
@@ -5402,7 +5415,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
         self.assertEqual(block_mask.BLOCK_SIZE, (128, 128))
 
         sliced_mask = block_mask[0]
-        self.assertEqual(sliced_mask.shape, (1, 128, 512))
+        self.assertEqual(sliced_mask.shape, (1, 1, 128, 512))
         self.assertIsNone(sliced_mask.q_indices)
         self.assertIsNone(sliced_mask.q_num_blocks)
 
@@ -5412,6 +5425,66 @@ def test_block_mask_operations_with_none_q_indices(self, device):
         self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu")
         self.assertIsNone(cpu_mask.q_indices)
 
+    @supported_platform
+    @skip_on_cpu
+    def test_broadcasted_head_block_mask(self, device):
+        torch.manual_seed(42)
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def get_mask_mod_with_offset(mask_mod, offset_tensor):
+            def _mask_mod(b, h, q, kv):
+                return mask_mod(b, h, q + offset_tensor, kv)
+
+            return _mask_mod
+
+        B, T, H, D, current_pos = 4, 512, 8, 64, 128
+        dtype = torch.float32
+
+        q = torch.randn(B, H, 1, D, device=device, dtype=dtype)
+        k_cache = torch.randn(B, H, T, D, device=device, dtype=dtype)
+        v_cache = torch.randn(B, H, T, D, device=device, dtype=dtype)
+
+        # Keep future tokens tiny to avoid numerical issues when using full caches
+        k_cache[:, :, current_pos + 1 :, :] = (
+            torch.randn_like(k_cache[:, :, current_pos + 1 :, :]) * 1e-10
+        )
+        v_cache[:, :, current_pos + 1 :, :] = (
+            torch.randn_like(v_cache[:, :, current_pos + 1 :, :]) * 1e-10
+        )
+
+        k_cropped = k_cache[:, :, : current_pos + 1, :]
+        v_cropped = v_cache[:, :, : current_pos + 1, :]
+        sdpa_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k_cropped, v_cropped, attn_mask=None
+        )
+
+        base_mask = create_block_mask(
+            causal_mask,
+            B=B,
+            H=None,  # broadcast across heads
+            Q_LEN=T,
+            KV_LEN=T,
+            device=device,
+            _compile=True,
+        )
+
+        q_block_size = base_mask.BLOCK_SIZE[0]
+        block_offset = current_pos // q_block_size
+        mask_slice = base_mask[:, :, block_offset]
+
+        offset_tensor = torch.tensor(current_pos, device=device)
+        mask_slice.mask_mod = get_mask_mod_with_offset(
+            base_mask.mask_mod, offset_tensor
+        )
+        mask_slice.seq_lengths = (1, mask_slice.seq_lengths[1])
+
+        fa = torch.compile(flex_attention, dynamic=True)
+        flex_output = fa(q, k_cache, v_cache, block_mask=mask_slice)
+
+        self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3)
+
 
 @large_tensor_test_class("2GB", device=test_device[0])
 class TestPagedAttention(InductorTestCase):

torch/nn/attention/flex_attention.py

Lines changed: 9 additions & 0 deletions
@@ -649,6 +649,15 @@ def causal_mask(b, h, q_idx, kv_idx):
         assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
         assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)
         """
+        index = (index,) if not isinstance(index, tuple) else index
+        padded = (*index, slice(None), slice(None), slice(None))[:3]
+        sizes = self.kv_num_blocks.shape[:3]
+        index = tuple(
+            (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
+            if isinstance(i, int)
+            else i
+            for i, n in zip(padded, sizes)
+        )
         new_kv_num_blocks = self.kv_num_blocks[index]
         new_kv_indices = self.kv_indices[index]
         if self.full_kv_num_blocks is not None:
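The heart of the patch is the index normalization above: before kv_num_blocks/kv_indices are indexed, every int in the index (including negative ones) is rewritten as a length-1 slice, so the leading [B, H, Q_tiles] axes are preserved. Below is a standalone sketch of the same idea, using a plain tensor with made-up sizes rather than a real BlockMask; the helper name is hypothetical.

```python
import torch

def normalize_index(index, sizes):
    # Mirror of the logic added to __getitem__ in the diff above: pad the index
    # out to the three leading dims, then turn every int (negative or not) into
    # a length-1 slice so no axis is collapsed.
    index = (index,) if not isinstance(index, tuple) else index
    padded = (*index, slice(None), slice(None), slice(None))[:3]
    return tuple(
        (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
        if isinstance(i, int)
        else i
        for i, n in zip(padded, sizes)
    )

kv_num_blocks = torch.zeros(4, 2, 16, dtype=torch.int32)  # made-up [B, H, Q_tiles]

idx = normalize_index((0, 1), kv_num_blocks.shape[:3])
print(kv_num_blocks[idx].shape)  # torch.Size([1, 1, 16]) -- all axes kept

idx = normalize_index((-1, -2), kv_num_blocks.shape[:3])
print(kv_num_blocks[idx].shape)  # torch.Size([1, 1, 16]) -- negative ints handled
```

Keeping the axes means the sliced mask's metadata still matches the kernel's fixed [B, H, Q_tiles, ...] stride layout, which is why the broadcasted-head decode path in the new test now agrees with SDPA.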
