
Commit 49046e0

Authored by pytorchbot, eqy, malfet, and Skylion007
[cuDNN][SDPA] Check-in test for pytorch#166211 (pytorch#167121)
[cuDNN][SDPA] Check-in test for pytorch#166211 (pytorch#166570)

Repros without the need for specific tensor data. Should be passing with cuDNN frontend 1.15.0, which current `main` has.

Pull Request resolved: pytorch#166570
Approved by: https://github.com/atalman

(cherry picked from commit 71a2e93)

Co-authored-by: Eddie Yan <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: Aaron Gokaslan <[email protected]>
1 parent 4aca6a7 commit 49046e0
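
As a quick sanity check before running the new test, the snippet below (a hedged sketch, not part of this commit) prints the local PyTorch and cuDNN library versions and whether the cuDNN SDPA backend is enabled. Note that torch.backends.cudnn.version() reports the cuDNN library, not the cuDNN frontend 1.15.0 mentioned in the message, and torch.backends.cuda.cudnn_sdp_enabled() assumes a reasonably recent CUDA build of PyTorch.

# Hedged environment check, not part of the commit.
import torch

print("PyTorch:", torch.__version__)
print("cuDNN library:", torch.backends.cudnn.version())               # library version, not the frontend version
print("cuDNN SDPA enabled:", torch.backends.cuda.cudnn_sdp_enabled()) # assumes a recent CUDA build of PyTorch

The new test itself can then be run on a CUDA machine with, for example, python test/test_transformers.py -k test_cudnn_attention_broken_166211.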

File tree

1 file changed: +24 −0 lines changed


test/test_transformers.py

Lines changed: 24 additions & 0 deletions
@@ -2855,6 +2855,30 @@ def test_cudnn_attention_seqlen1_dropout_heuristic(self):
         out = torch.nn.functional.scaled_dot_product_attention(q, q, q, dropout_p=0.5)
         out.backward(grad)
 
+    @skipIfRocm
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_broken_166211(self):
+        # https://github.com/pytorch/pytorch/issues/166211#issue-3551350377
+        shape = (20, 4, 4, 32)
+        scale = 10
+        for i in range(100):
+            q = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale
+            k = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale
+            v = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale
+            q.requires_grad = True
+            k.requires_grad = True
+            v.requires_grad = True
+
+            grad_attn_output = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale
+
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+                dq, dk, dv = torch.autograd.grad(outputs=attn_output, inputs=(q, k, v), grad_outputs=grad_attn_output)
+
+            self.assertFalse(dq.isnan().any())
+            self.assertFalse(dk.isnan().any())
+            self.assertFalse(dv.isnan().any())
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
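
For convenience, here is a minimal standalone sketch adapted from the test above, for reproducing the check outside the unittest harness. It assumes a CUDA device with cuDNN SDPA support and a PyTorch build where torch.nn.attention.sdpa_kernel and SDPBackend.CUDNN_ATTENTION are available; the test's 100-iteration loop is dropped for brevity.

# Hedged standalone sketch adapted from test_cudnn_attention_broken_166211 above;
# not part of the commit. Large-magnitude bf16 inputs provoke the NaN gradients
# reported in pytorch#166211 on affected cuDNN frontend versions.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

shape = (20, 4, 4, 32)   # (batch, heads, seq_len, head_dim)
scale = 10

def make_input():
    # Scale first so the tensor stays a leaf, matching the test's setup.
    t = torch.randn(*shape, device="cuda", dtype=torch.bfloat16) * scale
    t.requires_grad_(True)
    return t

q, k, v = make_input(), make_input(), make_input()
grad_attn_output = torch.randn(*shape, device="cuda", dtype=torch.bfloat16) * scale

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    dq, dk, dv = torch.autograd.grad(
        outputs=attn_output, inputs=(q, k, v), grad_outputs=grad_attn_output
    )

# With a fixed cuDNN frontend (>= 1.15.0 per the commit message) none of the
# gradients should contain NaNs.
for name, g in (("dq", dq), ("dk", dk), ("dv", dv)):
    assert not g.isnan().any(), f"{name} contains NaNs"
print("no NaNs in dq/dk/dv")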
