
Commit 5990481

BoyuanFeng authored and mansiag05 committed
[Graph Partition] improve custom op output alias (pytorch#163227)
For a custom op with multiple outputs, we will see the following generated code:

```
buf1 = op1(arg0)
buf3 = buf1[0]
buf4 = buf1[1]
del buf1  # <--- if buf1 is not accessed in the future
```

If `buf1` is not accessed in the future, it is good to deallocate it early, so we do not delay the `del` until both `buf3` and `buf4` are no longer in use. Note that `buf3` and `buf4` hold references to the data, so `del buf1` does not prevent their use.

However, when there are mutating args, we do not see `del buf1` emitted immediately:

```python
@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)
```

(Screenshot from the PR showing the generated code: https://github.com/user-attachments/assets/3d1d1f5a-9749-4652-bb02-da593c78702d)

Why? Because `buf3` is a MultiOutput with `buf1` as input, and it believes `buf1` (an output of FallbackKernel `op1`) has inputs that alias its output.

https://github.com/pytorch/pytorch/blob/72fedf05752069c9e8b97c64397aedf6ee2bf5ec/torch/_inductor/ir.py#L7976-L7982

According to `[NOTE: FallbackKernel supported operators]`, for a mutating op that is auto-functionalizable, the outputs should NOT alias any of the inputs. This PR improves `get_inputs_that_alias_output` of `FallbackKernel` accordingly.

Use case: [moe custom op in vllm](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py#L2057-L2064)

Pull Request resolved: pytorch#163227
Approved by: https://github.com/zou3519
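To reproduce the observation above outside the test suite, here is a minimal sketch. It is not part of the PR: it assumes a CUDA device and uses `torch._inductor.utils.run_and_get_code` to dump the wrapper code Inductor generates, then filters for `del` statements and references to the op so you can see where the output buffer is freed.

```python
# Minimal sketch (not from the PR): inspect where Inductor emits `del` for the
# output of a multi-output custom op with mutating args. Requires a CUDA device.
import torch
from torch._inductor.utils import run_and_get_code


@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)


@op1.register_fake
def _(x) -> tuple[torch.Tensor, torch.Tensor]:
    return (torch.empty_like(x), torch.empty_like(x))


def f(x):
    a, b = op1(x)
    return a + b


compiled = torch.compile(f)
_, codes = run_and_get_code(compiled, torch.randn(8, device="cuda"))
wrapper = "\n".join(codes)

# With this patch, the `del` of the op's tuple buffer should appear right after
# the MultiOutput buffers are extracted rather than at the end of the graph.
for line in wrapper.splitlines():
    if line.lstrip().startswith("del ") or "op1" in line:
        print(line)
```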
1 parent 43c8fa6 commit 5990481

File tree

- test/inductor/test_cudagraph_trees.py
- test/inductor/test_perf.py
- torch/_inductor/ir.py

3 files changed: +78 −4 lines


test/inductor/test_cudagraph_trees.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -3231,6 +3231,60 @@ def fn(x):
         # splitting on 1 custom gives 2 cudagraphs
         self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+    @config.patch(implicit_fallbacks=True)
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_custom_op_mutation_late_free(self):
+        @torch.library.custom_op(
+            "mylib::op1",
+            mutates_args=["x"],
+            schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
+            device_types="cuda",
+        )
+        def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
+            x = x + 1
+            return (x + 1, x + 2)
+
+        @op1.register_fake
+        def _(x) -> tuple[torch.Tensor, torch.Tensor]:
+            return (torch.empty_like(x), torch.empty_like(x))
+
+        @torch.library.custom_op(
+            "mylib::cg_unsafe_op",
+            mutates_args=[],
+            schema="(Tensor x, Tensor y, Tensor x1, Tensor y1) -> Tensor",
+            device_types="cuda",
+            tags=(torch._C.Tag.cudagraph_unsafe,),
+        )
+        def cg_unsafe_op(x0, x1, y0, y1) -> torch.Tensor:
+            return x0 + x1 + y0 + y1
+
+        @cg_unsafe_op.register_fake
+        def _(x0, x1, y0, y1) -> torch.Tensor:
+            return torch.empty_like(x0)
+
+        def f(x):
+            x = x + 1
+            x = op1(x)
+            x0, x1 = x[0], x[1]
+            y0 = x0 + 1
+            y1 = x1 + 1
+            y = cg_unsafe_op(x0, x1, y0, y1)
+            z = y + x0 + x1
+            z0, z1 = op1(z)
+            z2 = z0 + z1
+            res = cg_unsafe_op(z2, z2, y, y)
+            return res
+
+        x = torch.randn(2, 2, device="cuda")
+        x_cloned = x.clone()
+        eager_out = f(x)
+
+        f_compiled = torch.compile(f, mode="reduce-overhead")
+
+        for _ in range(5):
+            compiled_out = f_compiled(x_cloned)
+            self.assertEqual(eager_out, compiled_out)
+
     @config.patch(implicit_fallbacks=True)
     @torch._inductor.config.patch("graph_partition", True)
     def test_graph_partition_custom_op_dynamoc_shapes(self):
```

test/inductor/test_perf.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -1156,11 +1156,13 @@ def f():
             torch.compile(f, fullgraph=True),
         )
 
-        # Check that we are allocating the minimum number of intermediate buffers
+        # Check that we do not allocate intermediate buffers
+        # which can be reused.
         matches = re.findall(r"empty_strided_\w+\(", code)
-        self.assertEqual(len(matches), 1)
+        self.assertEqual(len(matches), 0)
+        self.assertEqual("in_out" in code, True)
 
-        self.assertExpectedInline(count_numel(f), """39""")
+        self.assertExpectedInline(count_numel(f), """45""")
 
     @requires_cuda_and_triton
     def test_inplace_triton_kernel_v1(self):
```

torch/_inductor/ir.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -7550,7 +7550,25 @@ def has_side_effects(self) -> bool:
         return get_schema_info(self.op_overload).is_mutable()
 
     def get_inputs_that_alias_output(self) -> Sequence[str]:
-        return self.alias_names
+        assert isinstance(
+            self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
+        ), (
+            f"Fails to create FallbackKernel for {self.op_overload}: "
+            f"{type(self.op_overload)} not supported"
+        )
+
+        # See [Note: FallbackKernel supported operators]: for a mutating
+        # op that is auto-functionalizable, its outputs do NOT
+        # alias any of the inputs.
+        if (
+            not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
+            and "_c10d_functional" not in self.op_overload.name()
+            and self.op_overload._schema.is_mutable
+            and can_auto_functionalize(self.op_overload)
+        ):
+            return []
+        else:
+            return self.alias_names
 
     def get_mutation_names(self) -> Sequence[str]:
         assert len(self.mutation_names) <= 1
```
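For context, here is a minimal sketch of the two predicates the new branch relies on, using the same op shape as the test above. This is not part of the PR: the op name `mylib::op1_alias_demo` is hypothetical, and `can_auto_functionalize` is assumed to be importable from `torch._higher_order_ops.auto_functionalize` (the path used by current PyTorch).

```python
# Minimal sketch (not from the PR) of the checks in the patched
# FallbackKernel.get_inputs_that_alias_output. Registration alone does not
# require a GPU; only calling the op does.
import torch
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize


@torch.library.custom_op(
    "mylib::op1_alias_demo",  # hypothetical name, to avoid clashing with mylib::op1
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1_alias_demo(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)


op = torch.ops.mylib.op1_alias_demo.default  # the OpOverload Inductor sees

print(op._schema.is_mutable)       # True: the schema has a Tensor(a!)? argument
print(can_auto_functionalize(op))  # True: mutating, but auto-functionalizable

# Since the op is an OpOverload (not a HigherOrderOperator), is not a
# _c10d_functional collective, is mutable, and is auto-functionalizable, the
# patched get_inputs_that_alias_output() returns [] for it, letting Inductor
# free the FallbackKernel's output buffer as soon as the MultiOutput views
# have been extracted.
```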
