Commit ebd29a1

shunting314 authored and pytorchmergebot committed
[inductor] fuse for scalar shared data (pytorch#162311)
LOAF previously could skip these fusion opportunities and cause some tests to fail.

Test:
- TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python test/inductor/test_torchinductor_strided_blocks.py TritonBlockPointerTestGPU.test_2d_reduction_odd_shapes_view_size4_num_block_pointers_1_num_triton_kernels_1_reduction_op4_cuda

Pull Request resolved: pytorch#162311
Approved by: https://github.com/jansel
1 parent 5793dd7 · commit ebd29a1
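Editorial note: LOAF ("loop ordering after fusion") reorders loops after the initial fusion decisions so that more node pairs end up sharing memory accesses. As a rough illustration of the pattern the new test exercises, the following minimal sketch compiles a scalar reduction with LOAF enabled and counts the generated Triton kernels. It assumes a CUDA device and uses run_and_get_code from torch._inductor.utils; the hard-coded "cuda" string stands in for the test suite's GPU_TYPE helper.

import os

# Assumption: set the env var before importing torch so inductor's
# config is guaranteed to pick it up.
os.environ["TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION"] = "1"

import torch
from torch._inductor.utils import run_and_get_code

@torch.compile
def f(x):
    # mean() reduces the whole tensor to a scalar that downstream
    # nodes share, the case this commit lets LOAF keep fusing.
    return torch.mean(x)

x = torch.randn(5, 5, device="cuda")
out, code = run_and_get_code(f, x)
print(code[0].count("@triton.jit"))  # expected: 1 after this commit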

2 files changed: 52 additions, 16 deletions

test/inductor/test_loop_ordering.py

Lines changed: 17 additions, 0 deletions

@@ -592,6 +592,23 @@ def f(x, y):
             ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
         ).run(code[0])
 
+    def test_fuse_with_scalar_shared_memory(self):
+        """
+        Make sure if we can fuse two nodes sharing a scalar before,
+        we can still do it with LOAF applied.
+
+        This is not really a big deal. But some tests rely on this and
+        less number of kernels has some small benefits.
+        """
+
+        @torch.compile
+        def f(x):
+            return torch.mean(x)
+
+        x = torch.randn([5, 5], device=GPU_TYPE)
+        out, code = run_and_get_code(f, x)
+        FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])
+
 
     @inductor_config.patch(
         {
torch/_inductor/scheduler.py

Lines changed: 35 additions, 16 deletions

@@ -309,8 +309,8 @@ def log_details(self) -> None:
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
-        return
+    ) -> bool:
+        return False
 
     def update_mutated_names(self, renames: dict[str, str]) -> None:
         self.mutation_renames = {
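Editorial note: the base-class hunk above changes reorder_loops_by_dep_pair from a silent no-op into a predicate, so callers can distinguish "loops were reordered" from "nothing happened". A minimal sketch of the new contract, using hypothetical simplified classes rather than the real scheduler hierarchy:

class BaseNode:
    def reorder_loops(self) -> bool:
        # Base case: nothing to reorder; report failure instead of None.
        return False

class LoopNode(BaseNode):
    def __init__(self, can_decide_order: bool) -> None:
        self.can_decide_order = can_decide_order

    def reorder_loops(self) -> bool:
        if not self.can_decide_order:
            return False  # previously a bare `return`, invisible to callers
        # ... apply the new loop order here ...
        return True

print(LoopNode(can_decide_order=True).reorder_loops())  # True
print(BaseNode().reorder_loops())                       # False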
@@ -1130,6 +1130,11 @@ def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
 
 
 class SchedulerNode(BaseSchedulerNode):
+    """
+    A SchedulerNode is a node for scheduling that encapsulates either
+    a ComputedBuffer or a TemplateBuffer.
+    """
+
     _sizes: tuple[Sequence[sympy.Expr], ...]
     _body: LoopBody

@@ -1254,7 +1259,7 @@ def merge_loops(self) -> None:
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
+    ) -> bool:
         new_order = None
         self_sizes = self._sizes[0]
         if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
@@ -1266,11 +1271,13 @@
                 "Reorder loops for %s with order %s", self.get_name(), new_order
             )
             self.apply_new_loop_order(new_order)
+            return True
         else:
             loop_ordering_log.debug(
                 "Don't reordering %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
+            return False
 
     def debug_str_extra(self) -> str:
         name = self.get_name()
@@ -1527,18 +1534,21 @@ def estimate_flops(self) -> int | None:
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
+    ) -> bool:
+        """
+        Return true if a loop reordering is performed.
+        """
         if self.is_template():
             # We can not really reorder loops for a triton template
-            return
+            return False
         self_sizes = None
         for snode in self.snodes:
             assert isinstance(snode, SchedulerNode)
             if self_sizes is not None and tuple(self_sizes) != tuple(snode._sizes[0]):
                 loop_ordering_log.debug(
                     "Can not reorder fused node due to different sizes"
                 )
-                return
+                return False
             self_sizes = snode._sizes[0]
         new_order = None
 
@@ -1551,7 +1561,7 @@ def reorder_loops_by_dep_pair(
                 "Dont reordering fused node %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
-            return
+            return False
         metrics.num_loop_reordering += 1
         loop_ordering_log.debug(
             "Reorder loops for fused node %s with order %s", self.get_name(), new_order
@@ -1561,6 +1571,7 @@
             snode.apply_new_loop_order(new_order)
 
         refresh_group_node_dependencies(self)
+        return True
 
     def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
         super().__init__(scheduler)
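Editorial note: for a fused node the method now also reports failure when its constituent nodes disagree on iteration sizes, since a fused group can only adopt one common loop order. A self-contained sketch of that size check (hypothetical helper, not the real FusedSchedulerNode code):

def fused_reorder(sizes_per_node: list[tuple[int, ...]]) -> bool:
    # A fused group must share one loop nest, so every member needs
    # identical iteration sizes before any reordering is attempted.
    first = sizes_per_node[0]
    if any(sizes != first for sizes in sizes_per_node[1:]):
        return False  # mirrors "Can not reorder fused node due to different sizes"
    # ... choose and apply a common new order here ...
    return True

print(fused_reorder([(8, 4), (8, 4)]))  # True
print(fused_reorder([(8, 4), (4, 8)]))  # False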
@@ -3900,21 +3911,26 @@ def shared_data_after_reordering_loop(
         Right now just greedily reorder the loop of node1 to be compatible with node2,
         but ideally we should have some heuristics to reorder the loop for node2
         to be compatible with node1 if that's more efficient.
+
+        Return the amount of shared data re-computed in this method.
+        If no such recomputation happens, return -1 (not return 0 since 0 is a valid
+        amount of shared data).
+
         """
 
         # TODO Don't do loop reordering for CPU for now.
         # Should debug more why it does not work for CPU codegen
         if not config.loop_ordering_after_fusion or any(
             n.is_cpu() for n in [node1, node2]
         ):
-            return 0
+            return -1
 
         node1_buffer_names = node1.read_writes.buffer_names()
         node2_buffer_names = node2.read_writes.buffer_names()
         # Fast path: no common buffers.
         common_buffer_names = node1_buffer_names & node2_buffer_names
         if not common_buffer_names:
-            return 0
+            return -1
 
         node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
         node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
@@ -3937,13 +3953,13 @@
         )
 
         if len(candidates) == 0:
-            return 0
+            return -1
 
         # Pick the largest buffer to guide the loop reordering
         _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))
 
         if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
-            return 0
+            return -1
 
         if lhs_dep.num_vars != rhs_dep.num_vars:
             # this can happen due to we don't merge loops.
@@ -3952,21 +3968,22 @@
             # normalization (merging loops)
         if lhs_dep.normalize() == rhs_dep.normalize():
             return self.dep_size_hint(lhs_dep)
-        return 0
+        return -1
 
+        reordered = False
         # Only reorder loops for pointwise for now
         if not node1.is_reduction():
-            node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
+            reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
         elif not node2.is_reduction():
-            node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
+            reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
         else:
             loop_ordering_log.debug(
                 "Don't reorder loops since both nodes are reductions: %s v.s. %s",
                 node1.get_name(),
                 node2.get_name(),
             )
 
-        return self.score_fusion_memory(node1, node2)
+        return self.score_fusion_memory(node1, node2) if reordered else -1
 
     def unfusable_node(self, node: BaseSchedulerNode) -> bool:
         """
@@ -4255,7 +4272,9 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
             shared_data_score < config.score_fusion_memory_threshold
             and config.loop_ordering_after_fusion
         ):
-            shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
+            new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
+            if new_shared_data_score >= 0:
+                shared_data_score = new_shared_data_score
 
         if config.expand_dimension_for_pointwise_nodes and (
             expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
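Editorial note: the interaction between the -1 sentinel and can_fuse is the heart of the fix. Previously, when no reordering happened, shared_data_after_reordering_loop returned 0 and clobbered a possibly positive shared-data score, which is how fusions over a shared scalar were lost. A self-contained sketch of the sentinel pattern, with hypothetical names and a made-up threshold:

THRESHOLD = 1  # hypothetical fusion threshold, not the real config value

def score_after_reordering(reordered: bool, shared_bytes: int) -> int:
    # Return the recomputed score only if a reordering was performed;
    # -1 (not 0) signals "keep the old score", since 0 shared bytes is
    # itself a valid amount of shared data.
    return shared_bytes if reordered else -1

def can_fuse(old_score: int, reordered: bool, shared_bytes: int) -> bool:
    new_score = score_after_reordering(reordered, shared_bytes)
    if new_score >= 0:  # only overwrite when a recomputation happened
        old_score = new_score
    return old_score >= THRESHOLD

# Before this commit the failed-reordering path returned 0 and blocked
# the fusion below; with the sentinel the original score survives.
print(can_fuse(old_score=4, reordered=False, shared_bytes=0))  # True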
