Commit 1cd83de
[Flex attention] Fix flex attention head broadcast (pytorch#164368)
[Flex attention] Fix flex attention head broadcast (pytorch#163426)
Fixes part of pytorch#163314
In particular, it fixes **Bug 1: H=None Broadcasting Produces Incorrect Results**.
This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (`mask[:, :, i]`). That form of indexing collapses the Q dimension, so `kv_num_blocks`/`kv_indices` lose their expected `[B, H, Q_tiles, ...]` shape. Because the shape is lost, the kernel's stride math reads wrong offsets even though the mask_mod itself remains interpretable, producing silent numerical mismatches against regular SDPA, especially during single-position decoding with H broadcasting.
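A minimal sketch of the dimension collapse, using plain tensors rather than the real BlockMask internals (the `kv_num_blocks` shape here is illustrative):

```python
import torch

# Stand-in for block-sparse metadata with the expected [B, H, Q_tiles] layout.
kv_num_blocks = torch.arange(2 * 4 * 8).reshape(2, 4, 8)

collapsed = kv_num_blocks[:, :, 3]    # int index drops the Q axis -> [2, 4]
preserved = kv_num_blocks[:, :, 3:4]  # length-1 slice keeps the rank -> [2, 4, 1]

print(collapsed.shape)   # torch.Size([2, 4])
print(preserved.shape)   # torch.Size([2, 4, 1])

# The kernel's stride math assumes the [B, H, Q_tiles, ...] layout is intact;
# feeding it the collapsed tensor makes those stride products address the
# wrong elements, which corrupts the output silently rather than erroring.
```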
The fact that the B=None, H=None case works is accidental: with a singleton batch/head, the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE == 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer, and the kernel happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with the wrong strides, causing silent errors.
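A minimal sketch of the accidental-correctness arithmetic described above; the variable names mirror the kernel's, but the offset formula is a simplification for illustration, not the actual Triton source:

```python
SPARSE_Q_MULTIPLE = 2  # illustrative tile multiple

def sparse_offset(off_zq, off_hq, q_start, B_sparse, H_sparse,
                  stride_z, stride_h, stride_q):
    # Indices into the (possibly broadcast) sparse metadata.
    sparse_idx_z = off_zq % B_sparse
    sparse_idx_hq = off_hq % H_sparse
    q_tile = q_start // SPARSE_Q_MULTIPLE
    return sparse_idx_z * stride_z + sparse_idx_hq * stride_h + q_tile * stride_q

# Singleton batch/head and a single Q tile: every term is multiplied by 0,
# so even a bogus stride from the collapsed Q axis cannot move the pointer.
print(sparse_offset(off_zq=0, off_hq=5, q_start=0, B_sparse=1, H_sparse=1,
                    stride_z=999, stride_h=999, stride_q=12345))  # -> 0

# With H_sparse > 1 the head term becomes nonzero and the wrong strides
# shift the read to the wrong tile.
print(sparse_offset(off_zq=0, off_hq=5, q_start=0, B_sparse=1, H_sparse=4,
                    stride_z=999, stride_h=999, stride_q=12345))  # -> 999
```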
Pull Request resolved: pytorch#163426
Approved by: https://github.com/drisspg
(cherry picked from commit 1a42656)
Co-authored-by: Isalia20 <[email protected]>
File tree
- test/inductor
- torch/nn/attention

2 files changed: +89 -7 lines
[Diff for the first file (test/inductor): changed content not captured in extraction; hunks around original lines 4666-4723 and 5402-5417, with roughly 80 lines added (new lines up to 5487) and 7 removed.]
[Diff for the second file (torch/nn/attention): changed content not captured in extraction; one hunk around original lines 649-654 with 9 lines added at new lines 652-660.]