
Commit 10b501f

pytorchbot and drisspg authored
[Flex] Fix silent correctness w/ backpropping grads (pytorch#164366)
[Flex] Fix silent correctness w/ backpropping grads (pytorch#163677)

Fixes pytorch#162228

# Summary

The majority of our tests compile flex-attention in isolation. For fake tensor propagation, this means the input primals and all captured buffers don't go through any intermediate computation below autograd, so by happenstance they match the `requires_grad`-ness of the eager implementation and this check passes. However, if score_mod captures the result of some other intermediate fake tensor propagation, that result is not guaranteed to have accurate `requires_grad`-ness, which is what was happening here.

TL;DR: this check was belt-and-suspenders that turned out to be actively harmful; we should just let joint-graph tracing produce the correct joint graph.

Pull Request resolved: pytorch#163677
Approved by: https://github.com/ydwu4

(cherry picked from commit e2ce79e)

Co-authored-by: drisspg <[email protected]>
1 parent 31c72b8 commit 10b501f
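To make the failure mode concrete, here is a minimal repro sketch of the pattern described in the summary. It is illustrative, not taken from the test suite: the names `attend` and `bias_source` are made up, and a CUDA device is assumed. The key point is that `score_mod` closes over an intermediate tensor derived from a leaf, so its `requires_grad`-ness is only known accurately after tracing through that intermediate computation.

```python
# Hypothetical repro sketch: score_mod captures an intermediate tensor,
# so fake tensor prop cannot rely on its requires_grad flag matching eager.
import torch
from torch.nn.attention.flex_attention import flex_attention


@torch.compile
def attend(q, k, v, bias_source):
    # bias_mat is an intermediate tensor derived from a leaf that requires grad.
    bias_mat = bias_source[:, :, None] + bias_source[:, None, :]  # (B, L, L)

    def score_mod(score, b, h, q_idx, kv_idx):
        return score + bias_mat[b, q_idx, kv_idx]

    return flex_attention(q, k, v, score_mod=score_mod)


B, H, L, D = 2, 4, 16, 64
q = torch.randn(B, H, L, D, device="cuda", requires_grad=True)
bias_source = torch.randn(B, L, device="cuda", requires_grad=True)

attend(q, q, q, bias_source).sum().backward()
# Before this fix, the gradient flowing back into bias_source could be
# silently dropped; with the fix, bias_source.grad is populated.
assert bias_source.grad is not None
```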

File tree

2 files changed: +30 -1 lines changed


test/inductor/test_flex_attention.py

Lines changed: 29 additions & 0 deletions
@@ -6546,6 +6546,35 @@ def bias_mod(score, b, h, q_idx, kv_idx):
         assert bias.grad, "No gradient computed for bias"
         assert torch.any(bias.grad != 0), "Gradient for bias is 0"
 
+    @skip_on_cpu
+    def test_backprop_error_case(self, device):
+        @torch.compile()
+        def test(x, y):
+            # Materialize a bias matrix
+            B, L, device = x.shape[0], x.shape[1], x.device
+            b = torch.arange(B, device=device, dtype=torch.long).view(B, 1, 1)
+            q_idx = torch.arange(L, device=device, dtype=torch.long).view(1, L, 1)
+            kv_idx = torch.arange(L, device=device, dtype=torch.long).view(1, 1, L)
+            bias_mat = y[b, q_idx] + y[b, kv_idx]  # (B, L, L)
+
+            # Dummy score_mod retrieving bias values
+            def score_mod(score, b, h, q_idx, kv_idx):
+                return score + bias_mat[b, q_idx, kv_idx]
+
+            x_ = x[:, :, None].repeat(1, 1, 16, 1)
+            # torch._dynamo.graph_break()
+            return flex_attention(x_, x_, x_, score_mod=score_mod)
+
+        B, L, D = 2, 16, 64
+
+        x = torch.randn(B, L, D, device=device, requires_grad=True)
+        y = torch.randn(B, L, device=device, requires_grad=True)
+
+        _ = test(x, y).mean().backward()
+
+        assert x.grad.norm() > 0
+        assert y.grad.norm() > 0
+
     @skip_on_cpu
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"

torch/_higher_order_ops/flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -1266,7 +1266,7 @@ def flex_attention_backward_fake_tensor_mode(
             [
                 (
                     torch.empty_like(buffer, memory_format=torch.contiguous_format)
-                    if isinstance(buffer, torch.Tensor) and buffer.requires_grad
+                    if isinstance(buffer, torch.Tensor)
                     else None
                 )
                 for buffer in score_mod_other_buffers
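For context on what the one-line change does, here is a small illustrative sketch (not the HOP code itself) of the comprehension after the fix: every tensor in `score_mod_other_buffers` now gets an empty gradient placeholder regardless of its `requires_grad` flag, non-tensor captures still map to `None`, and pruning of unneeded gradients is left to joint-graph tracing. The example buffers below are made up.

```python
# Illustrative sketch of the post-fix behavior of the comprehension above.
import torch

score_mod_other_buffers = (
    torch.randn(8, 8, requires_grad=True),   # trainable captured buffer
    torch.randn(8, 8, requires_grad=False),  # non-trainable buffer: previously mapped to None
    3.14,                                     # non-tensor capture
)

grads = [
    (
        torch.empty_like(buffer, memory_format=torch.contiguous_format)
        if isinstance(buffer, torch.Tensor)
        else None
    )
    for buffer in score_mod_other_buffers
]

assert grads[0] is not None and grads[1] is not None and grads[2] is None
```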
