
Commit dfa2649

Revert "[Inductor] Fix epilogue fusion decision with 1 Triton caller as choice (pytorch#156500)"
This reverts commit c48d0f4. Reverted pytorch#156500 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#156500 (comment)))
1 parent 5277276 commit dfa2649

File tree

2 files changed (+3 −85 lines)


test/inductor/test_max_autotune.py

Lines changed: 1 addition & 61 deletions
@@ -10,7 +10,6 @@
 import re
 import tempfile
 import unittest
-from functools import partial
 from typing import Callable, Optional
 from unittest import mock
 from unittest.mock import MagicMock
@@ -36,11 +35,7 @@
     TritonTemplate,
     TritonTemplateCaller,
 )
-from torch._inductor.template_heuristics import (
-    BaseConfigHeuristic,
-    CUDAConfigHeuristic,
-    GemmConfig,
-)
+from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -1555,61 +1550,6 @@ def f(a, b):
         if "benchmark_gpu" in counter:
             self.assertEqual(counters["inductor"][counter], 2)

-    @unittest.skipIf(
-        not has_triton_tma_device(), "Need device-side TMA support in Triton"
-    )
-    @config.patch(
-        max_autotune=True,
-        max_autotune_gemm_backends="TRITON",
-        autotune_fallback_to_aten=False,
-    )
-    def test_one_triton_choice_epilogue_fusion(self):
-        """
-        Here we test the fusion case with only 1 Triton choice for mm lowering.
-        The hardcoded config itself is valid, but when fused with the torch.float32
-        case, the shared memory requirements is higher than the amount available on H100.
-
-        This test checks that the fusion does not occur in this edge case. This is important
-        for future work on lookup table for autotuned gemm configs.
-        """
-
-        def f(a, b):
-            return (a @ b).to(torch.float32)
-
-        a = torch.randn(512, 1152, device="cuda", dtype=torch.bfloat16)
-        b = torch.randn(1152, 7680, device="cuda", dtype=torch.bfloat16)
-
-        config_heuristic = BaseConfigHeuristic()
-        with config.patch(
-            {
-                "triton.enable_persistent_tma_matmul": "1",
-            }
-        ):
-            with (
-                mock.patch(
-                    "torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
-                ) as base_mm_mock,
-                mock.patch(
-                    "torch._inductor.kernel.mm.V.choices.get_persistent_mm_configs"
-                ) as persistent_mm_mock,
-            ):
-                base_mm_mock.return_value = partial(
-                    config_heuristic.preprocess_mm_configs, configs=[]
-                )
-                persistent_mm_mock.return_value = partial(
-                    config_heuristic.preprocess_mm_configs,
-                    configs=[GemmConfig(256, 128, 64, 4, 8, 8)],
-                )
-
-                compiled_f = torch.compile(f)
-                out, code = run_and_get_code(compiled_f, a, b)
-
-                FileCheck().check("triton_tem_fused_mm").check(
-                    "triton_poi_fused__to_copy"
-                ).run(code[0])
-
-                torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2)
-

 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
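
For context, here is a minimal standalone sketch of the upcast-epilogue pattern the reverted test exercised: a bf16 matmul whose result is cast to fp32, compiled with autotuning enabled. The shapes match the test above, but the compile mode and the CUDA availability guard are assumptions for illustration, not part of the reverted code.

# Minimal sketch (assumed setup, not the reverted test itself): a bf16 matmul
# whose result is upcast to fp32. Inductor lowers the matmul to a Triton GEMM
# template and must decide whether to fuse the fp32 cast epilogue into it; the
# reverted change benchmarked that fusion to catch shared-memory overflows when
# only one Triton config is available.
import torch

def f(a, b):
    return (a @ b).to(torch.float32)  # upcast epilogue (2-byte -> 4-byte dtype)

if torch.cuda.is_available():
    a = torch.randn(512, 1152, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(1152, 7680, device="cuda", dtype=torch.bfloat16)
    compiled_f = torch.compile(f, mode="max-autotune")
    out = compiled_f(a, b)
    torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2)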

torch/_inductor/scheduler.py

Lines changed: 2 additions & 24 deletions
@@ -2831,20 +2831,6 @@ def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
             for n in node_list
         )

-    def _template_upcast(
-        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
-    ) -> bool:
-        # Check if fusing an upcast onto a Triton template. If so, we want to benchmark
-        # the fusion to make sure that shared memory requirements are still met
-        return (
-            isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
-            and node1.node is not None
-            and node2.node is not None
-            and hasattr(node1.node, "get_dtype")
-            and hasattr(node2.node, "get_dtype")
-            and node1.node.get_dtype().itemsize < node2.node.get_dtype().itemsize
-        )
-
     def speedup_by_fusion(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> Union[bool, Callable[[], bool]]:
@@ -2858,12 +2844,7 @@ def speedup_by_fusion(
             and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
             for n in (node1, node2)
         )
-
-        if (
-            not self._template_upcast(node1, node2)
-            and not config.benchmark_fusion
-            and not is_multi_template
-        ):
+        if not config.benchmark_fusion and not is_multi_template:
             return True

         if (
@@ -3094,10 +3075,7 @@ def benchmark_when_ready() -> bool:

             except NoTritonConfigsError:
                 return False
-            except RuntimeError as e:
-                if "out of resource" in str(e):
-                    return False
-                raise
+
             except CompilationError as e:
                 if "Loop-carried variable" in str(e):
                     return True
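
As a reference for the reverted heuristic, here is a minimal self-contained sketch of the dtype check that _template_upcast was built on. Plain torch.dtype values stand in for scheduler nodes, and the helper name is hypothetical; only the itemsize comparison comes from the removed code.

# Hypothetical standalone illustration of the removed check: an epilogue counts
# as an "upcast" when its output dtype is wider than the template's output dtype.
# Fusing such an epilogue into a Triton GEMM template can raise shared-memory
# usage, which is why the reverted code routed these fusions through benchmarking.
import torch

def is_upcast_epilogue(template_dtype: torch.dtype, epilogue_dtype: torch.dtype) -> bool:
    return template_dtype.itemsize < epilogue_dtype.itemsize

assert is_upcast_epilogue(torch.bfloat16, torch.float32)      # 2 bytes -> 4 bytes
assert not is_upcast_epilogue(torch.float32, torch.bfloat16)  # downcast, not flagged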
