|
48 | 48 | skipCPUIf,
|
49 | 49 | skipCUDAIf,
|
50 | 50 | )
|
| 51 | +from torch.testing._internal.common_utils import IS_FBCODE |
51 | 52 | from torch.utils._triton import has_triton, has_triton_tma_device
|
52 | 53 |
|
53 | 54 |
|
@@ -4339,6 +4340,41 @@ def simple_score_mod(score, b, h, q_idx, kv_idx):
|
4339 | 4340 | fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
|
4340 | 4341 | fa._WARNINGS_SHOWN = original_warnings_shown
|
4341 | 4342 |
|
@largeTensorTest("38GB", "cuda")  # memory requirement determined empirically
@skip_on_cpu
@unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
def test_int64_indexing_large_stride(self, device):
    """Smoke-test compiled flex_attention on tensors with 2**32 elements.

    Each of q, k and v has B * H * S * D = 1 * 64 * 2**20 * 64 = 2**32
    elements, so element offsets no longer fit in a 32-bit index. The test
    runs a forward and a backward pass under a causal block mask and only
    checks shapes and gradient presence; a full numerical comparison
    against a reference would need too much memory.
    """
    B = 1
    H = 64
    S = 2**20
    D = 64
    dtype = torch.float16

    def _simple_causal(b, h, q_idx, kv_idx):
        # Each query position may attend to itself and to earlier positions.
        return q_idx >= kv_idx

    BLOCK_M = 1024
    BLOCK_N = 1024

    block_mask = torch.compile(create_block_mask)(
        _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
    )

    q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

    # Test forward and backward pass
    out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
    loss = out.sum()
    loss.backward()

    # Basic correctness checks; doing a full compare consumes too much memory.
    self.assertEqual(out.shape, (B, H, S, D))
    for leaf in (q, k, v):
        # assertIsNotNone reports a clearer failure than assertTrue(... is not None).
        self.assertIsNotNone(leaf.grad)
        # Shape comparison is metadata-only, so it costs no extra device memory.
        self.assertEqual(leaf.grad.shape, leaf.shape)
| 4377 | + |
4342 | 4378 |
|
4343 | 4379 | class TestBlockMask(InductorTestCase):
|
4344 | 4380 | def setUp(self):
|
|
0 commit comments