
Commit b8b81a9

jataylo and eellison authored
update mutation renames (pytorch#153895) (#2717)
Upstream cherry pick required to fix dependency issue discovered in max autotune workloads. Thanks to @PaulZhang12 for original find. When we finalize a multi template buffer, we need to reflect mutation renaming in dependencies.

Pull Request resolved: pytorch#153895
Approved by: https://github.com/PaulZhang12
(cherry picked from commit 35ddad2)

Fixes #ISSUE_NUMBER

Co-authored-by: eellison <[email protected]>
1 parent 975f61d commit b8b81a9
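
Before the diffs, a quick illustration: the hunk in torch/_inductor/scheduler.py below builds an inverse of the scheduler's mutation-rename map from the original node's dependencies and applies it to the replacement node's dependencies. The sketch here shows that mechanic in isolation; the Dep class, buffer names, and map contents are simplified assumptions, not Inductor's actual data structures.

# A minimal, runnable sketch of the renaming idea (simplified stand-ins, not
# Inductor's real types): when one scheduler node replaces another, the
# replacement's dependencies should use the same buffer names as the original.
import itertools
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Dep:
    name: str

    def rename(self, renames: dict[str, str]) -> "Dep":
        # Point the dep at a different buffer name if a rename applies.
        return replace(self, name=renames.get(self.name, self.name))


# Hypothetical bookkeeping: the original node read "buf1", and the scheduler's
# mutation map says that name really refers to "buf1_mutated".
mutation_real_name = {"buf1": "buf1_mutated"}
old_reads = [Dep("buf1")]
old_unmet = [Dep("buf1")]

# Mirror of the fix: invert the map so the real name points back to the name
# the original node used ...
mutation_renames: dict[str, str] = {}
for dep in itertools.chain(old_reads, old_unmet):
    if real_name := mutation_real_name.get(dep.name, None):
        mutation_renames[real_name] = dep.name

# ... then normalize the replacement node's deps with it, so both nodes refer
# to the mutated buffer the same way.
new_reads = {Dep("buf1_mutated")}
assert {dep.rename(mutation_renames) for dep in new_reads} == {Dep("buf1")}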

File tree

2 files changed: +76 −0 lines changed

test/inductor/test_max_autotune.py
torch/_inductor/scheduler.py


test/inductor/test_max_autotune.py

Lines changed: 49 additions & 0 deletions
@@ -1,6 +1,9 @@
 # Owner(s): ["module: inductor"]
 import contextlib
+import functools
+import inspect
 import json
+import logging
 import math
 import os
 import tempfile
@@ -34,6 +37,7 @@
     parametrize,
     TEST_WITH_ROCM,
 )
+from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.utils._triton import has_triton_tma_device


@@ -928,6 +932,51 @@ def f(x, y):
         f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
         self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)

+    @config.patch("trace.enabled", True)
+    @config.patch({"test_configs.force_extern_kernel_in_multi_template": True})
+    def test_mutation_rename(self):
+        torch._logging.set_logs(ir_post_fusion=True)
+
+        def f(x, y, z, other):
+            mul = x * y
+            diag = torch.diagonal(mul)
+            diag.copy_(other)
+            x = torch.mm(mul, z)
+            y = torch.diagonal(x).add_(torch.tensor(1, device="cuda"))
+            return y
+
+        t = functools.partial(torch.randn, device="cuda")
+        inps = (t(3, 3), t(3, 3), t(3, 3), t(3))
+        fn = torch.compile(f, mode="max-autotune-no-cudagraphs")
+        (
+            pre_fusion_stream,
+            post_fusion_stream,
+        ), ctx = multiple_logs_to_string(
+            "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
+        )
+
+        with config.patch({"trace.debug_dir": tempfile.mkdtemp()}):
+            with self.assertLogs(
+                logging.getLogger("torch._inductor.debug"), level=logging.INFO
+            ) as cm, ctx():
+                out = fn(*inps)
+
+        self.assertEqual(f(*inps), out)
+
+        pre_fusion_stream = cm.output[0]
+        post_fusion_stream = cm.output[1]
+
+        # Before and after finalizing the multi-template buffer, deps should have
+        # the same normalization wrt writes.
+        FileCheck().check("MultiTemplateBuffer").check("unmet").check_same("buf1").run(
+            pre_fusion_stream
+        )
+        FileCheck().check("ExternKernelSchedulerNode").check("unmet").check_same(
+            "buf1"
+        ).run(post_fusion_stream)
+
+        torch._logging.set_logs()
+
     @config.patch({"test_configs.force_extern_kernel_in_multi_template": True})
     def test_cat_max_autotune_extern(self):
         self._test_cat_max_autotune_impl(using_triton_mm=False)
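
As a usage note (not part of the commit): test_mutation_rename allocates every tensor on "cuda" and compiles with max-autotune-no-cudagraphs, so it presumably requires a CUDA-capable machine; under a typical PyTorch checkout, something like pytest test/inductor/test_max_autotune.py -k test_mutation_rename should select just this test.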

torch/_inductor/scheduler.py

Lines changed: 27 additions & 0 deletions
@@ -2635,6 +2635,15 @@ def benchmark_codegened_module(
         return backend.benchmark_codegened_module(module)

     def finalize_multi_template_buffers(self) -> None:
+        """
+        Finalize a backing choice for MultiTemplateBuffers which did not already have a
+        choice finalized through fusion. In the case of an extern choice, this will result
+        in replacing the SchedulerNode.
+
+        If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice
+        will force completion of compilation and benchmarking.
+        """
+
         def replace_operation_buffer(
             orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
         ) -> None:
@@ -2702,6 +2711,24 @@ def replace_operation_buffer(
                 self.name_to_node[node.get_name()] = new_scheduler_node
                 self.name_to_fused_node[node.get_name()] = new_scheduler_node

+                # We need to reflect the mutation renames that were recorded in the original node
+                mutation_renames = {}
+                for dep in itertools.chain(
+                    node.read_writes.reads, node.unmet_dependencies
+                ):
+                    if real_name := self.mutation_real_name.get(dep.name, None):
+                        mutation_renames[real_name] = dep.name
+
+                def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
+                    return OrderedSet(dep.rename(mutation_renames) for dep in deps)
+
+                new_scheduler_node.unmet_dependencies = rename_deps(
+                    new_scheduler_node.unmet_dependencies
+                )
+                new_scheduler_node.read_writes.reads = rename_deps(
+                    new_scheduler_node.read_writes.reads
+                )
+
                 for new_out, old_out in zip(
                     new_scheduler_node.get_outputs(), node.get_outputs()
                 ):
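
Read together with the new test, the intent of this hunk is that the replacement node produced when an extern choice is finalized ends up with unmet_dependencies and read_writes.reads normalized under the same mutation renames as the original MultiTemplateBuffer, which is what the pre-fusion and post-fusion FileCheck assertions on buf1 verify.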
