Commit ba86395
[Graph Partition] fix graph partition input signature for fallback kernels (pytorch#166985)
[Graph Partition] fix graph partition input signature for fallback kernels (pytorch#165815)
Scheduler relies on node.last_usage to free buffers. `last_usage` may contain a buffer that is allocated in previous graph partition AND not directly accessed in the current graph partition.
## Example
```python
def f(x):
y = x + 1
z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
z_cpu = z.cpu()
u_cuda = z_cpu.cuda()
return u_cuda
```
In the generated code, we have
```
def partition_0(args):
...
# Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
buf2 = buf1 # <------- buf2 is buf1
assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
return (buf2, )
def call(self, args):
...
(buf2,) = self.partitions[0](partition0_args)
...
buf3.copy_(buf2, False)
del buf0
del buf1
del buf2 # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
...
```
Note: view is treated as a fallback kernel due to its special dtype.
https://github.com/pytorch/pytorch/blob/de09bab4b66002a8a9a2195f50f96a78868a3d39/torch/_inductor/lowering.py#L841-L843
## Fix
This PR fixes the issue by also returning these buffers to be freed later.
Pull Request resolved: pytorch#165815
Approved by: https://github.com/eellison
(cherry picked from commit 1891239)
Co-authored-by: Boyuan Feng <[email protected]>1 parent f190bda commit ba86395
File tree
2 files changed
+26
-0
lines changed- test/inductor
- torch/_inductor
2 files changed
+26
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2805 | 2805 | | |
2806 | 2806 | | |
2807 | 2807 | | |
| 2808 | + | |
| 2809 | + | |
| 2810 | + | |
| 2811 | + | |
| 2812 | + | |
| 2813 | + | |
| 2814 | + | |
| 2815 | + | |
| 2816 | + | |
| 2817 | + | |
| 2818 | + | |
| 2819 | + | |
| 2820 | + | |
| 2821 | + | |
| 2822 | + | |
| 2823 | + | |
2808 | 2824 | | |
2809 | 2825 | | |
2810 | 2826 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4926 | 4926 | | |
4927 | 4927 | | |
4928 | 4928 | | |
| 4929 | + | |
| 4930 | + | |
| 4931 | + | |
| 4932 | + | |
| 4933 | + | |
| 4934 | + | |
| 4935 | + | |
| 4936 | + | |
| 4937 | + | |
| 4938 | + | |
4929 | 4939 | | |
4930 | 4940 | | |
4931 | 4941 | | |
| |||
0 commit comments