
Commit d6d6fa2

Revert "bwd pass (pytorch#164504)"
This reverts commit f36f372. Reverted pytorch#164504 on behalf of https://github.com/jeffdaily: CI had been clean for both CUDA and ROCm before the merge, but it broke post-merge. ([comment](pytorch#164504 (comment)))
Parent: 467c21a

3 files changed: +26, -370 lines

test/test_varlen_attention.py

Lines changed: 13 additions & 220 deletions
@@ -5,12 +5,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.attention.varlen import varlen_attn
+from torch.nn.attention import varlen_attn
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import parametrize, run_tests
-from torch.utils._python_dispatch import TorchDispatchMode
 
 
 VarlenShape = namedtuple(
@@ -24,18 +23,6 @@
 }
 
 
-class OpLoggingMode(TorchDispatchMode):
-    """Logging mode that captures all dispatched operations"""
-
-    def __init__(self):
-        self.called_ops = []
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        op_name = str(func)
-        self.called_ops.append(op_name)
-        return func(*args, **(kwargs or {}))
-
-
 class AttentionBlock(nn.Module):
     def __init__(
         self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@@ -52,9 +39,12 @@ def __init__(
             embed_dim, embed_dim, bias=False, device=device, dtype=dtype
         )
 
-    def get_varlen_qkv(
+    def forward_varlen(
         self,
         x_packed: torch.Tensor,
+        cu_seq: torch.Tensor,
+        max_len: int,
+        is_causal: bool = False,
     ):
         qkv = self.qkv_proj(x_packed)
         q, k, v = qkv.chunk(3, dim=-1)
@@ -63,51 +53,24 @@ def get_varlen_qkv(
         k = k.view(-1, self.num_heads, self.head_dim)
         v = v.view(-1, self.num_heads, self.head_dim)
 
-        return q, k, v
-
-    def forward_varlen(
-        self,
-        x_packed: torch.Tensor,
-        cu_seq: torch.Tensor,
-        max_len: int,
-        is_causal: bool = False,
-    ):
-        q, k, v = self.get_varlen_qkv(x_packed)
-
-        attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
+        attn_out = varlen_attn(
+            q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
+        )
         attn_out = attn_out.view(-1, self.embed_dim)
 
         return self.out_proj(attn_out)
 
-    def forward_sdpa(
-        self,
-        x_padded: torch.Tensor,
-        seq_lengths: torch.Tensor,
-        dtype: torch.dtype,
-        is_causal: bool = False,
-    ):
+    def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
         batch_size, seq_len, _ = x_padded.shape
 
         qkv = self.qkv_proj(x_padded)
         q, k, v = qkv.chunk(3, dim=-1)
 
-        mask = (
-            torch.arange(seq_len, device=x_padded.device)[None, :]
-            < seq_lengths[:, None]
-        )
-
-        attn_mask = mask[:, None, None, :].expand(
-            batch_size, self.num_heads, seq_len, seq_len
-        )
-
         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        attn_out = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=is_causal
-        )
-
+        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
         attn_out = (
             attn_out.transpose(1, 2)
             .contiguous()
@@ -128,9 +91,7 @@ def create_variable_length_batch(
     seq_lengths = torch.tensor(seq_lengths, device=device)
    total_tokens = seq_lengths.sum().item()
 
-    x_packed = torch.randn(
-        total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
-    )
+    x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
 
     cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
     cu_seq[1:] = seq_lengths.cumsum(0)
@@ -145,7 +106,6 @@ def create_variable_length_batch(
         end_idx = start_idx + seq_len
         x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
         start_idx = end_idx
-    x_padded = x_padded.clone().detach().requires_grad_()
 
     return {
         "seq_lengths": seq_lengths,
@@ -173,11 +133,7 @@ def test_basic_functionality(self, device, dtype):
 
         total_tokens = shape.batch_size * shape.max_seq_len
         x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-            requires_grad=True,
+            total_tokens, shape.embed_dim, device=device, dtype=dtype
         )
         cu_seq = torch.tensor(
             [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -191,128 +147,6 @@ def test_basic_functionality(self, device, dtype):
         self.assertEqual(output.device, torch.device(device))
         self.assertEqual(output.dtype, dtype)
 
-        varlen_grad_out = torch.ones_like(output)
-
-        varlen_grad = torch.autograd.grad(
-            outputs=output,
-            inputs=x_packed,
-            grad_outputs=varlen_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        self.assertIsNotNone(varlen_grad)
-        self.assertEqual(varlen_grad.shape, x_packed.shape)
-        self.assertEqual(varlen_grad.dtype, x_packed.dtype)
-
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
-    )
-    @parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_custom_op_compliance(self, device, dtype):
-        torch.manual_seed(42)
-
-        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
-
-        attention_block = AttentionBlock(
-            shape.embed_dim, shape.num_heads, device, dtype
-        )
-
-        total_tokens = shape.batch_size * shape.max_seq_len
-        x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-        )
-        cu_seq = torch.tensor(
-            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
-        )
-
-        q, k, v = attention_block.get_varlen_qkv(x_packed)
-
-        torch.library.opcheck(
-            torch.ops.torch_attn._varlen_attn,
-            (q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
-        )
-
-        out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
-            q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
-        )
-        grad_out = torch.randn_like(out)
-
-        # we don't support double backward
-        # skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
-        torch.library.opcheck(
-            torch.ops.torch_attn._varlen_attn_backward,
-            (
-                grad_out,
-                q,
-                k,
-                v,
-                out,
-                lse,
-                cu_seq,
-                cu_seq,
-                shape.max_seq_len,
-                shape.max_seq_len,
-                False,
-                rng_state,
-            ),
-            test_utils=["test_schema", "test_faketensor"],
-        )
-
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
-    )
-    @parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_custom_op_registration(self, device, dtype):
-        torch.manual_seed(42)
-
-        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
-
-        attention_block = AttentionBlock(
-            shape.embed_dim, shape.num_heads, device, dtype
-        )
-
-        total_tokens = shape.batch_size * shape.max_seq_len
-        x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-            requires_grad=True,
-        )
-        cu_seq = torch.tensor(
-            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
-        )
-
-        compiled_forward = torch.compile(
-            attention_block.forward_varlen, backend="eager", fullgraph=True
-        )
-        with OpLoggingMode() as mode:
-            output = compiled_forward(
-                x_packed, cu_seq, shape.max_seq_len, is_causal=False
-            )
-
-            varlen_grad_out = torch.ones_like(output)
-            _ = torch.autograd.grad(
-                outputs=output,
-                inputs=x_packed,
-                grad_outputs=varlen_grad_out,
-                retain_graph=True,
-                create_graph=False,
-                allow_unused=False,
-            )[0]
-
-        called_ops = mode.called_ops
-
-        custom_ops_called = any(
-            "torch_attn._varlen_attn" in op for op in called_ops
-        ) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
-        assert custom_ops_called
-
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
     )
@@ -338,10 +172,7 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal):
             is_causal=is_causal,
         )
         sdpa_output = attention_block.forward_sdpa(
-            variable_length_batch_data["x_padded"],
-            variable_length_batch_data["seq_lengths"],
-            dtype=dtype,
-            is_causal=is_causal,
+            variable_length_batch_data["x_padded"], is_causal=is_causal
         )
 
         tolerances = default_tolerances[dtype]
@@ -355,44 +186,6 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal):
             torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
             start_idx = end_idx
 
-        varlen_grad_out = torch.ones_like(varlen_output)
-
-        sdpa_grad_out = torch.zeros_like(sdpa_output)
-
-        start_idx = 0
-        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
-            end_idx = start_idx + seq_len
-            sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
-            start_idx = end_idx
-
-        varlen_grad = torch.autograd.grad(
-            outputs=varlen_output,
-            inputs=variable_length_batch_data["x_packed"],
-            grad_outputs=varlen_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        sdpa_grad = torch.autograd.grad(
-            outputs=sdpa_output,
-            inputs=variable_length_batch_data["x_padded"],
-            grad_outputs=sdpa_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        start_idx = 0
-        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
-            end_idx = start_idx + seq_len
-
-            varlen_grad_seq = varlen_grad[start_idx:end_idx]
-            sdpa_grad_seq = sdpa_grad[i, :seq_len]
-
-            torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
-            start_idx = end_idx
-
 
 device_types = ("cuda",)
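What remains of the test after this revert still exercises the packed, padding-free calling convention visible in the hunks above: sequences are concatenated along the token dimension and described by a tensor of cumulative sequence lengths. The sketch below restates that convention outside the test harness. It is a minimal illustration only: the import path, the `varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=...)` signature, the int32 `cu_seq` construction, and the `(total_tokens, num_heads, head_dim)` packing are taken from the diff, while the concrete sizes are made up, and a CUDA build with Flash Attention support is assumed, matching the test's skip condition.

```python
# Minimal sketch (not part of the PR) of the packed "varlen" calling convention the test
# exercises. Import path and positional signature follow the diff above; sizes are
# illustrative, and a CUDA build with Flash Attention support is assumed.
import torch
from torch.nn.attention import varlen_attn

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 16, 64

# Two sequences packed back to back along the token dimension -- no padding tokens.
seq_lengths = torch.tensor([128, 384], device=device)
total_tokens = int(seq_lengths.sum())
max_len = int(seq_lengths.max())

# Cumulative sequence lengths, int32, built the same way as the test helper: [0, 128, 512].
cu_seq = torch.zeros(len(seq_lengths) + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)

# Packed q/k/v with shape (total_tokens, num_heads, head_dim), matching the test's .view calls.
q, k, v = (
    torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
    for _ in range(3)
)

# Same cu_seq/max_len on the query and key sides, as in forward_varlen above.
out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
print(out.shape)  # (total_tokens, num_heads, head_dim)
```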

torch/nn/attention/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -14,11 +14,14 @@
     SDPAParams,
 )
 
+from .varlen import varlen_attn
+
 
 __all__: list[str] = [
     "SDPBackend",
     "sdpa_kernel",
     "WARN_FOR_UNFUSED_KERNELS",
+    "varlen_attn",
 ]
 
 # Note: [SDPA warnings]
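The `torch/nn/attention/__init__.py` hunk re-exports `varlen_attn` at the package level and adds it to `__all__`, which is what lets the test switch to `from torch.nn.attention import varlen_attn`. A small sketch of what that re-export implies for callers, relying only on the two added lines shown above:

```python
import torch.nn.attention as attn
from torch.nn.attention.varlen import varlen_attn  # submodule the re-export comes from

# After this change the package-level name is an alias for the submodule's symbol,
# and it is part of the package's public surface via __all__.
assert attn.varlen_attn is varlen_attn
assert "varlen_attn" in attn.__all__
```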
