
Commit 1a42656

Isalia20 authored and pytorchmergebot committed
[Flex attention] Fix flex attention head broadcast (pytorch#163426)
Fixes part of pytorch#163314, specifically:

**Bug 1: H=None broadcasting produces incorrect results**

This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Even though the mask_mod remains "interpretable", the kernel's stride math then reads wrong offsets, and we get silent numerical mismatches compared to regular SDPA, especially during single-position decoding with H broadcasting.

The B=None, H=None case only works by accident: with a singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis does not move the pointer, and the kernel happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides, causing silent errors.

Pull Request resolved: pytorch#163426
Approved by: https://github.com/drisspg
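A minimal sketch of the offset arithmetic described above, in plain Python rather than the actual Triton kernel; the variable names mirror those quoted in the message, and the stride parameters are hypothetical stand-ins for the BlockMask strides:

```python
# Illustrative sketch only -- not the real kernel. It shows why a collapsed
# Q axis can still "work" when batch/head are singletons and there is a
# single Q tile: every term that would multiply the wrong stride is zero.
def sparse_block_offset(off_zq, off_hq, q_start,
                        sparse_b, sparse_h,
                        stride_b, stride_h, stride_q,
                        SPARSE_Q_MULTIPLE=1):
    sparse_idx_z = off_zq % sparse_b        # 0 whenever the batch dim is a singleton
    sparse_idx_hq = off_hq % sparse_h       # 0 whenever the head dim is a singleton
    sparse_q_tile = q_start // SPARSE_Q_MULTIPLE
    return (sparse_idx_z * stride_b
            + sparse_idx_hq * stride_h
            + sparse_q_tile * stride_q)

# Singleton batch/head and one Q tile: every term is 0, so even wrong strides
# leave the pointer on tile 0 and the result looks correct.
print(sparse_block_offset(3, 0, 0, sparse_b=1, sparse_h=1,
                          stride_b=999, stride_h=999, stride_q=999))  # 0

# With more heads in the sparse mask or multiple Q tiles, the nonzero terms
# multiply the wrong strides and the kernel silently reads the wrong blocks.
```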
1 parent bda9ab2 commit 1a42656

File tree

2 files changed: +89 -7 lines changed


test/inductor/test_flex_attention.py

Lines changed: 80 additions & 7 deletions

@@ -4672,8 +4672,8 @@ def causal_mask(b, h, q, kv):

         block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
-        self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
-        self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
+        self.assertEqual(block_mask[0].shape, (1, 2, 2048, 2048))
+        self.assertEqual(block_mask[0, 0].shape, (1, 1, 2048, 2048))
         self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
         self.assertEqual(block_mask.sparsity(), 46.875)
         self.assertEqual(block_mask[0].sparsity(), 46.875)
@@ -4717,13 +4717,26 @@ def causal_mask(b, h, q, kv):

         # Index on batch dimension
         new_block_mask = block_mask[0]
-        assert new_block_mask.kv_num_blocks.shape == (2, 4)
-        assert new_block_mask.kv_indices.shape == (2, 4, 4)
+        assert new_block_mask.kv_num_blocks.shape == (1, 2, 4)
+        assert new_block_mask.kv_indices.shape == (1, 2, 4, 4)

         # Index on batch and head dimension
         new_block_mask = block_mask[0, 1]
-        assert new_block_mask.kv_num_blocks.shape == (4,)
-        assert new_block_mask.kv_indices.shape == (4, 4)
+        assert new_block_mask.kv_num_blocks.shape == (
+            1,
+            1,
+            4,
+        )
+        assert new_block_mask.kv_indices.shape == (1, 1, 4, 4)
+
+        # Index on batch and head dimension with -1 semantics
+        new_block_mask = block_mask[-1, -2]
+        assert new_block_mask.kv_num_blocks.shape == (
+            1,
+            1,
+            4,
+        )
+        assert new_block_mask.kv_indices.shape == (1, 1, 4, 4)

         # slicing on batch and head dimension
         new_block_mask = block_mask[0:2, 1:2]
@@ -5408,7 +5421,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
         self.assertEqual(block_mask.BLOCK_SIZE, (128, 128))

         sliced_mask = block_mask[0]
-        self.assertEqual(sliced_mask.shape, (1, 128, 512))
+        self.assertEqual(sliced_mask.shape, (1, 1, 128, 512))
         self.assertIsNone(sliced_mask.q_indices)
         self.assertIsNone(sliced_mask.q_num_blocks)

@@ -5418,6 +5431,66 @@ def test_block_mask_operations_with_none_q_indices(self, device):
         self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu")
         self.assertIsNone(cpu_mask.q_indices)

+    @supported_platform
+    @skip_on_cpu
+    def test_broadcasted_head_block_mask(self, device):
+        torch.manual_seed(42)
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def get_mask_mod_with_offset(mask_mod, offset_tensor):
+            def _mask_mod(b, h, q, kv):
+                return mask_mod(b, h, q + offset_tensor, kv)
+
+            return _mask_mod
+
+        B, T, H, D, current_pos = 4, 512, 8, 64, 128
+        dtype = torch.float32
+
+        q = torch.randn(B, H, 1, D, device=device, dtype=dtype)
+        k_cache = torch.randn(B, H, T, D, device=device, dtype=dtype)
+        v_cache = torch.randn(B, H, T, D, device=device, dtype=dtype)
+
+        # Keep future tokens tiny to avoid numerical issues when using full caches
+        k_cache[:, :, current_pos + 1 :, :] = (
+            torch.randn_like(k_cache[:, :, current_pos + 1 :, :]) * 1e-10
+        )
+        v_cache[:, :, current_pos + 1 :, :] = (
+            torch.randn_like(v_cache[:, :, current_pos + 1 :, :]) * 1e-10
+        )
+
+        k_cropped = k_cache[:, :, : current_pos + 1, :]
+        v_cropped = v_cache[:, :, : current_pos + 1, :]
+        sdpa_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k_cropped, v_cropped, attn_mask=None
+        )
+
+        base_mask = create_block_mask(
+            causal_mask,
+            B=B,
+            H=None,  # broadcast across heads
+            Q_LEN=T,
+            KV_LEN=T,
+            device=device,
+            _compile=True,
+        )
+
+        q_block_size = base_mask.BLOCK_SIZE[0]
+        block_offset = current_pos // q_block_size
+        mask_slice = base_mask[:, :, block_offset]
+
+        offset_tensor = torch.tensor(current_pos, device=device)
+        mask_slice.mask_mod = get_mask_mod_with_offset(
+            base_mask.mask_mod, offset_tensor
+        )
+        mask_slice.seq_lengths = (1, mask_slice.seq_lengths[1])
+
+        fa = torch.compile(flex_attention, dynamic=True)
+        flex_output = fa(q, k_cache, v_cache, block_mask=mask_slice)
+
+        self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3)

 @large_tensor_test_class("2GB", device=test_device[0])
 class TestPagedAttention(InductorTestCase):

torch/nn/attention/flex_attention.py

Lines changed: 9 additions & 0 deletions

@@ -648,6 +648,15 @@ def causal_mask(b, h, q_idx, kv_idx):
         assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
         assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)
         """
+        index = (index,) if not isinstance(index, tuple) else index
+        padded = (*index, slice(None), slice(None), slice(None))[:3]
+        sizes = self.kv_num_blocks.shape[:3]
+        index = tuple(
+            (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
+            if isinstance(i, int)
+            else i
+            for i, n in zip(padded, sizes)
+        )
         new_kv_num_blocks = self.kv_num_blocks[index]
         new_kv_indices = self.kv_indices[index]
         if self.full_kv_num_blocks is not None:
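The fix above normalizes integer indices (including negative ones) on the first three dimensions into length-1 slices, so kv_num_blocks/kv_indices keep their [B, H, Q_tiles, …] rank after indexing. A standalone sketch of the same idea on plain tensors (not the library code; `keep_dims` is a hypothetical helper mirroring the logic in the diff):

```python
import torch

# An int index collapses a dimension, a length-1 slice keeps it:
kv_num_blocks = torch.zeros(4, 2, 16)       # [B, H, Q_tiles]
print(kv_num_blocks[0].shape)               # torch.Size([2, 16])    -- B collapsed
print(kv_num_blocks[0:1].shape)             # torch.Size([1, 2, 16]) -- B kept

def keep_dims(index, sizes):
    """Turn int indices (including negatives) on the first three dims into length-1 slices."""
    index = index if isinstance(index, tuple) else (index,)
    padded = (*index, slice(None), slice(None), slice(None))[:3]
    return tuple(
        (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
        if isinstance(i, int)
        else i
        for i, n in zip(padded, sizes)
    )

idx = keep_dims((0, -1), kv_num_blocks.shape[:3])
print(kv_num_blocks[idx].shape)             # torch.Size([1, 1, 16]) -- rank preserved
```

This is also why the updated tests expect `block_mask[0].shape` to be (1, 2, 2048, 2048) rather than (2, 2048, 2048).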
