Skip to content

Commit a8e9ed2

Browse files
shunting314 authored and pytorchmergebot committed
[inductor] turn on loaf (for oss) by default (pytorch#162030)
Pull Request resolved: pytorch#162030
Approved by: https://github.com/eellison, https://github.com/jansel
1 parent 0390798 commit a8e9ed2

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

test/dynamo/test_logging.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,11 @@ def test_fusion(self, records):
137137
fn_opt = torch.compile(inductor_schedule_fn, backend="inductor")
138138
fn_opt(torch.ones(1000, 1000, device=device_type))
139139
self.assertGreater(len(records), 0)
140-
self.assertLess(len(records), 8)
140+
141+
# LOAF will add an extra round of fusion and result in more logs
142+
self.assertLess(
143+
len(records), 8 * (1 + torch._inductor.config.loop_ordering_after_fusion)
144+
)
141145

142146
@requires_cuda_and_triton
143147
@make_logging_test(cudagraphs=True)

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,8 @@ def foo(x, length):
11041104
# bernoulli operation
11051105
# TODO: fails for triton CPU "Failed to convert to LLVM IR"
11061106
@test_torchinductor.xfail_if_triton_cpu
1107+
# Disable split_reductions on this test for now due to the interaction with LOAF
1108+
@config.patch(split_reductions=False)
11071109
def test_removed_buffers(self):
11081110
from torch.ops import aten
11091111

@@ -1114,8 +1116,8 @@ def fn(a):
11141116
result, code = self._run_and_compare(
11151117
fn,
11161118
*[torch.ones(200, 200, device=self.device) * p],
1117-
expected_num_triton_kernels=2,
1118-
expected_num_block_pointers=3,
1119+
expected_num_triton_kernels=1,
1120+
expected_num_block_pointers=1,
11191121
atol=p * 0.06,
11201122
rtol=0.06,
11211123
)

torch/_inductor/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,10 @@ def use_autoheuristic(name: str) -> bool:
649649
benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
650650
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
651651
loop_ordering_after_fusion: bool = (
652-
os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
652+
os.environ.get(
653+
"TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0" if is_fbcode() else "1"
654+
)
655+
== "1"
653656
)
654657

655658
# If fusing two nodes only save less then score_fusion_memory_threshold memory,

0 commit comments

Comments
 (0)