Commit 10b501f
[Flex] Fix silent correctness w/ backpropping grads (pytorch#164366)
[Flex] Fix silent correctness w/ backpropping grads (pytorch#163677)
Fixes pytorch#162228
# Summary
The majority of our tests only compile flex-attention in isolation. This means that during fake tensor propagation the input primals and all captured buffers don't go through any intermediate computation below autograd. As a result, they happen by chance to match the `requires_grad`-ness of the eager implementation, and this check passes. However, if score_mod captures the result of some other intermediate fake tensor computation, that tensor is not guaranteed to have accurate `requires_grad`-ness, which is what was happening here.
TLDR: this was a belt-and-suspenders check that was actually harmful; we should just let joint graph tracing create the correct joint graph.
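Below is a minimal hedged sketch of the scenario described above (it is not the test added in this PR). The function name `run`, the tensor shapes, and the captured `bias_base` buffer are illustrative assumptions; the key point is that the buffer captured by `score_mod` is itself the product of intermediate computation rather than a plain graph input.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def run(q, k, v, bias_base):
    # The captured buffer is the result of intermediate computation, so its
    # requires_grad-ness under fake tensor propagation is not simply
    # inherited from a graph input.
    bias = bias_base * 2

    def score_mod(score, b, h, q_idx, kv_idx):
        return score + bias[q_idx, kv_idx]

    return flex_attention(q, k, v, score_mod=score_mod)


q, k, v = (torch.randn(2, 4, 128, 64, requires_grad=True) for _ in range(3))
bias_base = torch.randn(128, 128, requires_grad=True)

out = torch.compile(run)(q, k, v, bias_base)
out.sum().backward()
# Before this fix, grads flowing back to the captured buffer could be
# silently incorrect; letting the joint graph handle it produces the
# correct gradients.
print(bias_base.grad.shape)
```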
Pull Request resolved: pytorch#163677
Approved by: https://github.com/ydwu4
(cherry picked from commit e2ce79e)
Co-authored-by: drisspg <[email protected]>
File tree
2 files changed: +30 −1 lines
- test/inductor
- torch/_higher_order_ops
[diff: test/inductor — 29 lines added (new file lines 6549–6577)]
[diff: torch/_higher_order_ops — line 1269 modified (−1/+1)]