Skip to content

Commit 5432966

Browse files
Revert "Remove test since it ooms on CI (pytorch#161644)"
This reverts commit 443452c. Reverted pytorch#161644 on behalf of https://github.com/atalman due to need to revert pytorch#157767 internal tests ([comment](pytorch#161644 (comment)))
1 parent e9975f5 commit 5432966

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

test/inductor/test_flex_attention.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
skipCPUIf,
4949
skipCUDAIf,
5050
)
51+
from torch.testing._internal.common_utils import IS_FBCODE
5152
from torch.utils._triton import has_triton, has_triton_tma_device
5253

5354

@@ -4339,6 +4340,41 @@ def simple_score_mod(score, b, h, q_idx, kv_idx):
43394340
fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
43404341
fa._WARNINGS_SHOWN = original_warnings_shown
43414342

4343+
@largeTensorTest("38GB", "cuda")  # empirically measured peak memory
@skip_on_cpu
@unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
def test_int64_indexing_large_stride(self, device):
    """Run flex_attention forward + backward on a 2**20-length sequence.

    With S = 2**20 and H = 64 the strides of q/k/v exceed int32 range,
    so this exercises the int64 indexing path in the generated kernels.
    Only shape and gradient-existence are checked, because a full
    numerical comparison against a reference would consume too much
    memory on the 38 GB budget this test already needs.
    """
    B = 1
    H = 64
    S = 2**20
    D = 64
    dtype = torch.float16

    def _simple_causal(b, h, q_idx, kv_idx):
        # Standard causal mask: query position may attend to itself
        # and earlier key positions.
        return q_idx >= kv_idx

    BLOCK_M = 1024
    BLOCK_N = 1024

    block_mask = torch.compile(create_block_mask)(
        _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
    )

    q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

    # Test forward and backward pass
    out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
    loss = out.sum()
    loss.backward()

    # Basic correctness checks; doing a full compare consumes too much memory :/
    self.assertEqual(out.shape, (B, H, S, D))
    self.assertIsNotNone(q.grad)
    self.assertIsNotNone(k.grad)
    self.assertIsNotNone(v.grad)
4377+
43424378

43434379
class TestBlockMask(InductorTestCase):
43444380
def setUp(self):

0 commit comments

Comments
 (0)