
Commit bc5111c

andyanwang authored and pytorchmergebot committed
[Inductor] Prevent kernel fusion with too many unique inputs and outputs (pytorch#166275)
MTIA Triton currently cannot handle kernels with too many input/output buffers. This PR adds a configurable limit so the scheduler prevents large fusions that would produce too many unique input/output buffers.

Differential Revision: [D85509351](https://our.internmc.facebook.com/intern/diff/D85509351/)

Pull Request resolved: pytorch#166275
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#166274
1 parent 398fdd3 commit bc5111c
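
For context, a minimal sketch of how the new limit could be enabled from user code; the threshold value (8) and the toy function are illustrative only and not part of this PR:

import torch
import torch._inductor.config as inductor_config

# The new option defaults to None, which disables the check entirely.
# Any integer threshold turns it on for all scheduler fusion decisions.
inductor_config.max_fusion_unique_io_buffers = 8

def f(a, b, c):
    # A small pointwise graph; with the limit set, the scheduler refuses to
    # fuse kernels whose unique input/output buffers would exceed 8.
    return (a + b) * c + c.sin()

compiled = torch.compile(f)
out = compiled(torch.randn(32), torch.randn(32), torch.randn(32))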

File tree

test/inductor/test_inductor_scheduler.py
torch/_inductor/choices.py
torch/_inductor/config.py
torch/_inductor/scheduler.py

4 files changed: +140 −0 lines changed

test/inductor/test_inductor_scheduler.py

Lines changed: 77 additions & 0 deletions
@@ -1,11 +1,14 @@
 # Owner(s): ["module: inductor"]

 from unittest import skipIf
+from unittest.mock import Mock

 import torch
 import torch._inductor.metrics as metrics
 import torch.utils.flop_counter
 from torch._dynamo.utils import counters
+from torch._inductor.dependencies import Dep, ReadWrites
+from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
 from torch._inductor.utils import fresh_inductor_cache
 from torch.testing._internal.common_cuda import SM70OrLater
 from torch.testing._internal.common_device_type import (
@@ -15,6 +18,7 @@
 )
 from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
 from torch.testing._internal.inductor_utils import IS_BIG_GPU
+from torch.utils._ordered_set import OrderedSet


 def FlopCounterMode(*args, **kwargs):
@@ -132,6 +136,79 @@ def test_flop_counter_op(self, device, dtype, options):
         counters["inductor"]["flop_count"] = 0
         torch._logging.set_logs()

+    def test_fusion_prevent_too_many_reads_and_writes_prevents_fusion(self):
+        """Test that fusion is prevented when unique I/O buffers exceed threshold"""
+        # Setup: Create nodes with many unique I/O buffers
+        # node1: reads [A, B, C], writes [D]
+        # node2: reads [D, E, F], writes [G]
+        # D becomes internal (node2 reads node1's write)
+        # After fusion: unique I/O = {A, B, C, E, F, G} = 6 buffers
+        scheduler = Mock(spec=Scheduler)
+        scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)
+
+        node1 = self._create_mock_node(
+            name="node1", reads=["A", "B", "C"], writes=["D"]
+        )
+        node2 = self._create_mock_node(
+            name="node2", reads=["D", "E", "F"], writes=["G"]
+        )
+
+        # Execute: Check with threshold of 5 (should prevent fusion since 6 > 5)
+        result = Scheduler.fusion_prevent_too_many_reads_and_writes(
+            scheduler, node1, node2, threshold=5
+        )
+
+        # Assert: Fusion should be prevented (6 unique buffers > 5 threshold)
+        self.assertTrue(result)
+
+    def test_fusion_prevent_too_many_reads_and_writes_allows_fusion(self):
+        """Test that fusion is allowed when intermediate buffers are removed"""
+        # Setup: Create nodes where node2 reads node1's output
+        # node1: reads [A, B], writes [C]
+        # node2: reads [C, D], writes [E]
+        # C becomes internal (node2 reads node1's write)
+        # After fusion: unique I/O = {A, B, D, E} = 4 buffers
+        scheduler = Mock(spec=Scheduler)
+        scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)
+
+        node1 = self._create_mock_node(name="node1", reads=["A", "B"], writes=["C"])
+        node2 = self._create_mock_node(name="node2", reads=["C", "D"], writes=["E"])
+
+        # Execute: Check with threshold of 5 (should allow fusion since 4 <= 5)
+        result = Scheduler.fusion_prevent_too_many_reads_and_writes(
+            scheduler, node1, node2, threshold=5
+        )
+
+        # Assert: Fusion should be allowed (4 unique buffers <= 5 threshold)
+        self.assertFalse(result)
+
+    def _create_mock_node(self, name: str, reads: list[str], writes: list[str]) -> Mock:
+        """Helper method to create a mock scheduler node with specified reads/writes"""
+        node = Mock(spec=BaseSchedulerNode)
+        node.get_name = Mock(return_value=name)
+        node.get_nodes = Mock(return_value=[node])
+
+        # Create mock Dep objects for reads and writes
+        read_deps = OrderedSet()
+        for read_name in reads:
+            dep = Mock(spec=Dep)
+            dep.name = read_name
+            read_deps.add(dep)
+
+        write_deps = OrderedSet()
+        for write_name in writes:
+            dep = Mock(spec=Dep)
+            dep.name = write_name
+            write_deps.add(dep)
+
+        # Create mock ReadWrites object
+        read_writes = Mock(spec=ReadWrites)
+        read_writes.reads = read_deps
+        read_writes.writes = write_deps
+
+        node.read_writes = read_writes
+        return node
+

 instantiate_device_type_tests(TestScheduler, globals())
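The new tests drive the scheduler entirely through mocks. One way to run just them locally (assuming a PyTorch source checkout) might be:

python test/inductor/test_inductor_scheduler.py -k fusion_prevent_too_many_reads_and_writes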

torch/_inductor/choices.py

Lines changed: 11 additions & 0 deletions
@@ -530,6 +530,17 @@ def can_fuse(
            WhyNoFuse(node1, node2)("Fusion will increase peak memory")
            return False

+        if (
+            config.max_fusion_unique_io_buffers is not None
+            and scheduler.fusion_prevent_too_many_reads_and_writes(
+                node1,
+                node2,
+                config.max_fusion_unique_io_buffers,
+            )
+        ):
+            WhyNoFuse(node1, node2)("fusion_prevent_too_many_reads_and_writes")
+            return False
+
        return True

    @staticmethod
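When this guard fires, the WhyNoFuse message above records the rejection reason. If memory serves, those messages go through Inductor's "fusion" log artifact, so something like the following should surface them (the script name is a placeholder):

TORCH_LOGS="fusion" python my_script.py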

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
@@ -688,6 +688,10 @@ def use_autoheuristic(name: str) -> bool:
# how many nodes to attempt pairwise fusion with in a buffer group
max_fusion_buffer_group_pairwise_attempts = 64

+# maximum number of unique input/output buffers allowed in fused kernels.
+# The check is disabled if set to None.
+max_fusion_unique_io_buffers: Optional[int] = None
+
# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

torch/_inductor/scheduler.py

Lines changed: 48 additions & 0 deletions
@@ -4113,6 +4113,54 @@ def _find_single_user_inputs(
                return True
        return False

+    def fusion_prevent_too_many_reads_and_writes(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
+    ) -> bool:
+        # After fusion, we need to calculate the unique I/O buffers
+        # accounting for buffers that become internal (removed through fusion)
+
+        # Get all nodes that will be in the fused node
+        fused_node_names = OrderedSet(
+            [node.get_name() for node in node1.get_nodes()]
+            + [node.get_name() for node in node2.get_nodes()]
+        )
+
+        # Calculate node2 reads that can be removed through fusion,
+        # i.e. node2 reads that are outputs of node1
+        node1_write_names = OrderedSet(dep.name for dep in node1.read_writes.writes)
+        node2_read_names = OrderedSet(dep.name for dep in node2.read_writes.reads)
+        reads_removed_through_fusion = node2_read_names & node1_write_names
+
+        # Calculate node1 writes that can be removed through fusion,
+        # i.e. node1 writes that are only read by node2
+        writes_removed_through_fusion: OrderedSet[str] = OrderedSet()
+        for write_dep in node1.read_writes.writes:
+            if self.can_buffer_be_removed_through_fusion(
+                write_dep.name, fused_node_names
+            ):
+                writes_removed_through_fusion.add(write_dep.name)
+
+        # Get all unique reads (union of both nodes' reads)
+        all_read_names = OrderedSet(
+            dep.name for dep in node1.read_writes.reads
+        ) | OrderedSet(dep.name for dep in node2.read_writes.reads)
+
+        # Get all unique writes (union of both nodes' writes)
+        all_write_names = OrderedSet(
+            dep.name for dep in node1.read_writes.writes
+        ) | OrderedSet(dep.name for dep in node2.read_writes.writes)
+
+        # Remove reads that become internal
+        unique_reads = all_read_names - reads_removed_through_fusion
+
+        # Remove writes that become internal
+        unique_writes = all_write_names - writes_removed_through_fusion
+
+        # Get all unique buffer names (reads and writes combined, but no double counting)
+        unique_io_buffers = unique_reads | unique_writes
+
+        return len(unique_io_buffers) > threshold
+
    def are_long_distant_nodes(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
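To make the bookkeeping above concrete, here is a small standalone sketch using plain Python sets. It assumes the simplest case, where every node1 output consumed by node2 becomes internal to the fused kernel (the real method defers that decision per buffer to can_buffer_be_removed_through_fusion); the buffer names mirror the unit tests added in this PR:

def fused_unique_io(node1_reads, node1_writes, node2_reads, node2_writes):
    # node1 outputs that node2 consumes become internal to the fused kernel
    internal = set(node1_writes) & set(node2_reads)
    reads = (set(node1_reads) | set(node2_reads)) - internal
    writes = (set(node1_writes) - internal) | set(node2_writes)
    return reads | writes

# node1: reads {A, B, C}, writes {D}; node2: reads {D, E, F}, writes {G}
# D is internal, leaving 6 unique I/O buffers -> fusion prevented at threshold 5
assert len(fused_unique_io({"A", "B", "C"}, {"D"}, {"D", "E", "F"}, {"G"})) == 6

# node1: reads {A, B}, writes {C}; node2: reads {C, D}, writes {E}
# C is internal, leaving 4 unique I/O buffers -> fusion allowed at threshold 5
assert len(fused_unique_io({"A", "B"}, {"C"}, {"C", "D"}, {"E"})) == 4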
