Commit 443452c

drisspg authored and pytorchmergebot committed
Remove test since it ooms on CI (pytorch#161644)
Pull Request resolved: pytorch#161644
Approved by: https://github.com/BoyuanFeng
1 parent 47ecd20 commit 443452c

test/inductor/test_flex_attention.py

Lines changed: 0 additions & 36 deletions
@@ -48,7 +48,6 @@
     skipCPUIf,
     skipCUDAIf,
 )
-from torch.testing._internal.common_utils import IS_FBCODE
 from torch.utils._triton import has_triton, has_triton_tma_device
 
 
@@ -4340,41 +4339,6 @@ def simple_score_mod(score, b, h, q_idx, kv_idx):
         fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
         fa._WARNINGS_SHOWN = original_warnings_shown
 
-    @largeTensorTest("38GB", "cuda")  # emperically
-    @skip_on_cpu
-    @unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
-    def test_int64_indexing_large_stride(self, device):
-        B = 1
-        H = 64
-        S = 2**20
-        D = 64
-        dtype = torch.float16
-
-        def _simple_causal(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        BLOCK_M = 1024
-        BLOCK_N = 1024
-
-        block_mask = torch.compile(create_block_mask)(
-            _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
-        )
-
-        q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-        k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-        v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-
-        # Test forward and backward pass
-        out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
-        loss = out.sum()
-        loss.backward()
-
-        # Basic correctness checks, doing full comapre consumes too much memory :/
-        self.assertEqual(out.shape, (B, H, S, D))
-        self.assertTrue(q.grad is not None)
-        self.assertTrue(k.grad is not None)
-        self.assertTrue(v.grad is not None)
-
 
 class TestBlockMask(InductorTestCase):
     def setUp(self):
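
For reference, the pattern the removed test exercised (building a block mask with torch.compile(create_block_mask), then running a forward and backward pass through torch.compile(flex_attention)) can be reproduced at a much smaller sequence length that does not run out of memory on CI hardware. The sketch below is not part of this commit: the shapes and the standalone-script form are assumptions, while the mask function and the call pattern are taken from the removed test.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Much smaller than the 2**20 sequence length used in the removed test,
# so this fits on a typical CI GPU.
B, H, S, D = 1, 4, 1024, 64
dtype = torch.float16
device = "cuda"

def _simple_causal(b, h, q_idx, kv_idx):
    # Same causal mask_mod as the removed test.
    return q_idx >= kv_idx

block_mask = torch.compile(create_block_mask)(_simple_causal, B, H, S, S, device=device)

q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

# Forward and backward pass through the compiled flex attention kernel.
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
out.sum().backward()

assert out.shape == (B, H, S, D)
assert q.grad is not None and k.grad is not None and v.grad is not None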
