Commit 9b2ebc0
Fix reduce_sum_transpose_rule which does a broadcast_in_dim to set out_sharding=operand.aval.to_cotangent_aval().sharding instead of operand.aval.sharding. THis is because if operand is reduced, then on bwd pass, we want the cotangent type to become unreduced.
PiperOrigin-RevId: 8345113021 parent b37b6c0 commit 9b2ebc0
2 files changed
+19
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7824 | 7824 | | |
7825 | 7825 | | |
7826 | 7826 | | |
7827 | | - | |
7828 | | - | |
| 7827 | + | |
| 7828 | + | |
| 7829 | + | |
7829 | 7830 | | |
7830 | 7831 | | |
7831 | 7832 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9631 | 9631 | | |
9632 | 9632 | | |
9633 | 9633 | | |
| 9634 | + | |
| 9635 | + | |
| 9636 | + | |
| 9637 | + | |
| 9638 | + | |
| 9639 | + | |
| 9640 | + | |
| 9641 | + | |
| 9642 | + | |
| 9643 | + | |
| 9644 | + | |
| 9645 | + | |
| 9646 | + | |
| 9647 | + | |
| 9648 | + | |
| 9649 | + | |
9634 | 9650 | | |
9635 | 9651 | | |
9636 | 9652 | | |
| |||
0 commit comments