You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[while_loop][inductor] fix aliased inputs by cloning (pytorch#160668)
[fx_graph_cse](https://github.com/pytorch/pytorch/blob/main/torch/_functorch/compile_utils.py#L46) is executed in min_cut partitioner which accidentally creates the aliasing for empty buffers and we could see the following graph node for joint graph with cmd: "pytest test/functorch/test_control_flow.py -k test_scan_multiple_layers_gradient_layers_2_device_cpu"
```python
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0_0, while_loop_body_graph_0_0, (full_default_4, empty_strided_default, full_default_2, full_default_3, full_default_2, full_default_3, full_default, full_default, rev, rev_1, rev_2, rev_3), (primals_4, primals_5, primals_6, primals_7));
```
Notice the operands sequence **"full_default_2, full_default_3, full_default_2, full_default_3, full_default, full_default"**, which indicates the gradient of different layers now sharing the same buffer, which create silent incorrectness.
Fixespytorch#158168.
Pull Request resolved: pytorch#160668
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#160548, pytorch#160374
0 commit comments