Commit ba86395

[Graph Partition] fix graph partition input signature for fallback kernels (pytorch#166985)
[Graph Partition] fix graph partition input signature for fallback kernels (pytorch#165815)

The scheduler relies on `node.last_usage` to free buffers. `last_usage` may contain a buffer that is allocated in a previous graph partition AND not directly accessed in the current graph partition.

## Example

```python
def f(x):
    y = x + 1
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    u_cuda = z_cpu.cuda()
    return u_cuda
```

In the generated code, we have:

```
def partition_0(args):
    ...
    # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn)  # <------ buf1 is a view of buf0
    buf2 = buf1  # <------- buf2 is buf1
    assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
    assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
    return (buf2, )

def call(self, args):
    ...
    (buf2,) = self.partitions[0](partition0_args)
    ...
    buf3.copy_(buf2, False)
    del buf0
    del buf1
    del buf2  # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
    ...
```

Note: `view` is treated as a fallback kernel due to its special dtype:
https://github.com/pytorch/pytorch/blob/de09bab4b66002a8a9a2195f50f96a78868a3d39/torch/_inductor/lowering.py#L841-L843

## Fix

This PR fixes the issue by also returning these buffers so they can be freed later.

Pull Request resolved: pytorch#165815
Approved by: https://github.com/eellison

(cherry picked from commit 1891239)

Co-authored-by: Boyuan Feng <[email protected]>
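For intuition, here is a minimal standalone sketch (illustrative only; it reuses the `buf0`/`buf1`/`buf2` names from the generated code above but is not actual Inductor output) of why deleting `buf2` is what ultimately releases `buf0`'s allocation: `aten.view.dtype` produces an alias, so the storage stays alive until the last alias is dropped.

```python
import torch

buf0 = torch.ones(8, dtype=torch.int32)
# aten.view.dtype reinterprets the same bytes: buf1 aliases buf0's storage.
buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn)
buf2 = buf1  # buf2 is just another alias of the same storage

assert buf2.untyped_storage().data_ptr() == buf0.untyped_storage().data_ptr()

del buf0, buf1
# The allocation is still alive because buf2 references it, so whoever holds
# buf2 must also be able to release the original buf0 storage; this is why
# the partition has to expose that buffer instead of keeping it internal.
print(buf2.untyped_storage().nbytes())  # 32 bytes (8 * int32) still allocated
```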
1 parent f190bda · commit ba86395

2 files changed: +26 -0 lines changed

test/inductor/test_cudagraph_trees.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -2805,6 +2805,22 @@ def f(x, y):
         # 2 graph partitions lead to 2 cudagraph
         self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+    def test_graph_partition_view_fallback(self):
+        def f(x):
+            y = x + 1
+            z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
+            z_cpu = z.cpu()
+            u_cuda = z_cpu.cuda()
+            return u_cuda
+
+        compiled_f = torch.compile(f, mode="reduce-overhead")
+
+        for _ in range(3):
+            x = torch.ones(2, dtype=torch.int32, device="cuda")
+            eager_out = f(x)
+            compiled_out = compiled_f(x)
+            self.assertEqual(eager_out, compiled_out)
+
     @torch._inductor.config.patch("graph_partition", True)
     def test_graph_partition_log_message(self):
         def foo(x, y):
```

torch/_inductor/scheduler.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -4926,6 +4926,16 @@ def is_none_layout(buf_name: str) -> bool:
         for node in partition:
             buffer_names_to_free.update(node.last_usage)
 
+        # buffer_names_to_free may contain buffers allocated in previous
+        # graph partitions. These buffers should also be a partition
+        # input.
+        extra_input_names = [
+            name
+            for name in (buffer_names_to_free - output_names)
+            if name in name_to_node
+        ]
+        partition_input_names.update(extra_input_names)
+
         input_nodes = {
             name: name_to_node[name]
             for name in partition_input_names
```
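As a rough illustration of the set arithmetic in the hunk above, with hypothetical buffer names (the real sets are built by the scheduler from `node.last_usage`, the partition outputs, and the known buffer nodes):

```python
# Hypothetical buffer names, mirroring the example in the commit message.
buffer_names_to_free = {"buf0", "buf2"}  # buffers whose last use is in this partition
output_names = {"buf2"}                  # buffers this partition already returns
name_to_node = {"buf0": "add_node", "buf2": "view_node"}  # buffers with a producing node

# A buffer scheduled to be freed here that is neither an output nor unknown
# must be threaded through as an extra partition input so it can be deleted.
extra_input_names = [
    name
    for name in (buffer_names_to_free - output_names)
    if name in name_to_node
]

partition_input_names = {"arg0_1"}  # hypothetical existing inputs
partition_input_names.update(extra_input_names)
print(sorted(partition_input_names))  # ['arg0_1', 'buf0']
```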
