|
48 | 48 | skipCPUIf,
|
49 | 49 | skipCUDAIf,
|
50 | 50 | )
|
51 |
| -from torch.testing._internal.common_utils import IS_FBCODE |
52 | 51 | from torch.utils._triton import has_triton, has_triton_tma_device
|
53 | 52 |
|
54 | 53 |
|
@@ -4340,41 +4339,6 @@ def simple_score_mod(score, b, h, q_idx, kv_idx):
|
4340 | 4339 | fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
|
4341 | 4340 | fa._WARNINGS_SHOWN = original_warnings_shown
|
4342 | 4341 |
|
4343 |
| - @largeTensorTest("38GB", "cuda") # emperically |
4344 |
| - @skip_on_cpu |
4345 |
| - @unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode") |
4346 |
| - def test_int64_indexing_large_stride(self, device): |
4347 |
| - B = 1 |
4348 |
| - H = 64 |
4349 |
| - S = 2**20 |
4350 |
| - D = 64 |
4351 |
| - dtype = torch.float16 |
4352 |
| - |
4353 |
| - def _simple_causal(b, h, q_idx, kv_idx): |
4354 |
| - return q_idx >= kv_idx |
4355 |
| - |
4356 |
| - BLOCK_M = 1024 |
4357 |
| - BLOCK_N = 1024 |
4358 |
| - |
4359 |
| - block_mask = torch.compile(create_block_mask)( |
4360 |
| - _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N) |
4361 |
| - ) |
4362 |
| - |
4363 |
| - q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True) |
4364 |
| - k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True) |
4365 |
| - v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True) |
4366 |
| - |
4367 |
| - # Test forward and backward pass |
4368 |
| - out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask) |
4369 |
| - loss = out.sum() |
4370 |
| - loss.backward() |
4371 |
| - |
4372 |
| - # Basic correctness checks, doing full comapre consumes too much memory :/ |
4373 |
| - self.assertEqual(out.shape, (B, H, S, D)) |
4374 |
| - self.assertTrue(q.grad is not None) |
4375 |
| - self.assertTrue(k.grad is not None) |
4376 |
| - self.assertTrue(v.grad is not None) |
4377 |
| - |
4378 | 4342 |
|
4379 | 4343 | class TestBlockMask(InductorTestCase):
|
4380 | 4344 | def setUp(self):
|
|
0 commit comments