
Commit 0390798

njriasan authored and pytorchmergebot committed
[Triton] [Inductor] Enable Epilogue Subtiling in the blackwell ws template (pytorch#163145)
Summary: Enables support for epilogue subtiling in the Blackwell warp-specialized (ws) template. This requires the ability to call `store_output` twice in the same kernel and to reuse the same tensor descriptor across allocations.

Test Plan: Tested with test_max_autotune.py on a Blackwell server.

Differential Revision: D82610077

Pull Request resolved: pytorch#163145
Approved by: https://github.com/eellison
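For orientation, a minimal sketch of how this path is exercised from user code, mirroring the test plan above. The config keys come from this PR; the shapes, dtype, and device are only illustrative, and actually hitting the Blackwell ws template assumes a Blackwell GPU with device-side TMA support in Triton.

```python
import torch
import torch._inductor.config as inductor_config


def mm(a, b):
    return a @ b


# Flags used by the tests in this PR; epilogue subtiling only applies when the
# blackwell_ws_persistent_device_tma template wins autotuning.
with inductor_config.patch(
    {
        "max_autotune": True,
        "triton.enable_persistent_tma_matmul": True,
        "triton.enable_template_tma_store": True,
        "triton.enable_epilogue_subtiling": True,
    }
):
    a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)  # illustrative sizes
    b = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
    c = torch.compile(mm)(a, b)
```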
1 parent 124dd36 commit 0390798

File tree: 7 files changed (+131, -57 lines)

test/inductor/test_max_autotune.py

Lines changed: 12 additions & 2 deletions

@@ -271,12 +271,14 @@ def next_multiple_16(a: int) -> int:
     @parametrize("b_transposed", (False, True))
     @parametrize("dynamic", (False, True))
     @parametrize("tma_store", (False, True))
+    @parametrize("epilogue_subtile", (False, True))
     def test_blackwell_max_autotune_regular_mm_persistent_tma(
         self,
         a_transposed: bool,
         b_transposed: bool,
         dynamic: bool,
         tma_store: bool,
+        epilogue_subtile: bool,
     ):
         def mm(a, b):
             # TMA requires 16-byte alignment: here we repeat the dims
@@ -308,13 +310,15 @@ def mm(a, b):
                 "max_autotune": True,
                 "triton.enable_persistent_tma_matmul": True,
                 "triton.enable_template_tma_store": tma_store,
+                "triton.enable_epilogue_subtiling": epilogue_subtile,
                 "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
             }
         ):
             c_actual, code = run_and_get_code(torch.compile(mm, dynamic=dynamic), a, b)
             c_expected = mm(a, b)

             torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
+            write_count = 2 if epilogue_subtile else 1
             if tma_store:
                 # Verify that we are using a TMA implementation
                 # Note: The tma_descriptor0 is generated by the kernel. If the
@@ -324,7 +328,9 @@ def mm(a, b):
                 write_api = "tl.store"
             FileCheck().check("triton_tem_fused_mm").check(
                 "triton.language.make_tensor_descriptor"
-            ).check("tl.load_tensor_descriptor").check(write_api).run(code[0])
+            ).check("tl.load_tensor_descriptor").check_count(write_api, write_count).run(
+                code[0]
+            )

     @unittest.skipIf(
         not has_triton_tma_device(), "Need device-side TMA support in Triton"
@@ -652,12 +658,14 @@ def addmm(x, a, b):
     @parametrize("b_transposed", (False, True))
     @parametrize("dynamic", (False, True))
     @parametrize("tma_store", (False, True))
+    @parametrize("epilogue_subtile", (False, True))
     def test_blackwell_max_autotune_addmm_persistent_tma(
         self,
         a_transposed: bool,
         b_transposed: bool,
         dynamic: bool,
         tma_store: bool,
+        epilogue_subtile: bool,
     ):
         def addmm(x, a, b):
             # TMA requires 16-byte alignment: here we repeat the dims
@@ -692,6 +700,7 @@ def addmm(x, a, b):
                 "max_autotune": True,
                 "triton.enable_persistent_tma_matmul": True,
                 "triton.enable_template_tma_store": tma_store,
+                "triton.enable_epilogue_subtiling": epilogue_subtile,
                 "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
             }
         ):
@@ -702,6 +711,7 @@ def addmm(x, a, b):

             make_desc_api = "triton.language.make_tensor_descriptor"
             read_api = "tl.load_tensor_descriptor"
+            write_count = 2 if epilogue_subtile else 1
             if tma_store:
                 # Verify that we are using a TMA implementation
                 # Note: The tma_descriptor0 is generated by the kernel. If the
@@ -713,7 +723,7 @@ def addmm(x, a, b):
             # Verify that we are using a TMA implementation
             FileCheck().check("triton_tem_fused_addmm").check(make_desc_api).check(
                 read_api
-            ).check(write_api).run(code[0])
+            ).check_count(write_api, write_count).run(code[0])

             torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

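The updated assertions rely on FileCheck's check_count to count write calls in the generated source: one when subtiling is off, two when it is on. A tiny standalone illustration of that idiom (the checked text below is made-up placeholder output, not real Inductor codegen):

```python
from torch.testing import FileCheck

# Hypothetical stand-in for Inductor's generated Triton source.
generated = """
triton_tem_fused_mm(...)
tma_descriptor0.store(...)
tma_descriptor0.store(...)
"""

# Passes when the pattern occurs (at least) twice after the first check,
# mirroring write_count = 2 for the epilogue_subtile=True case.
FileCheck().check("triton_tem_fused_mm").check_count(".store(", 2).run(generated)
```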

torch/_inductor/codegen/simd.py

Lines changed: 21 additions & 7 deletions

@@ -414,6 +414,8 @@ def __init__(
         )
         self.no_x_dim = self.want_no_x_dim()
         self.code_hash: Optional[str] = None
+        # Info to enable multiple store_output calls for epilogue subtiling
+        self.store_output_ctr = itertools.count()

         # define this in a closure to make cache local to object
         @functools.cache
@@ -427,6 +429,14 @@ def simplify_indexing(index: sympy.Expr):
         self.simplify_indexing = simplify_indexing
         self.initialize_range_tree(pid_cache)

+    def _get_store_output_subgraph_name(self, i: int) -> str:
+        return f"<STORE_OUTPUT_{i}>"
+
+    def get_store_output_count(self):
+        total = next(self.store_output_ctr)
+        self.store_output_ctr = itertools.count(start=total - 1, step=1)
+        return total
+
     @property
     @cache_property_on_self
     def num_reduction_dims(self) -> int:
@@ -1605,10 +1615,13 @@ def _codegen_single_template(

         partial_code = render()

-        with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-            for node in epilogue_nodes:
-                node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
-            kernel.cse.invalidate(OrderedSet())
+        num_store_subgraphs = kernel.get_store_output_count()
+        for i in range(num_store_subgraphs):
+            subgraph_name = kernel._get_store_output_subgraph_name(i)
+            with kernel.set_subgraph_body(subgraph_name):
+                for node in epilogue_nodes:
+                    node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+                kernel.cse.invalidate(OrderedSet())

         for input_name, buffer in kernel.named_input_nodes.items():
             subgraph_name = f"<LOAD_INPUT_{input_name}>"
@@ -1656,9 +1669,10 @@ def _codegen_single_template(
             subgraph_name = f"<LOAD_INPUT_{input_name}>"
             partial_code.finalize_hook(subgraph_name, strict=False)

-        with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-            if not isinstance(partial_code, str):
-                partial_code.finalize_hook("<STORE_OUTPUT>")
+        num_store_subgraphs = kernel.get_store_output_count()
+        for i in range(num_store_subgraphs):
+            subgraph_name = kernel._get_store_output_subgraph_name(i)
+            partial_code.finalize_hook(subgraph_name)

         if isinstance(partial_code, str):
             src_code = partial_code
torch/_inductor/codegen/triton.py

Lines changed: 42 additions & 32 deletions

@@ -1944,6 +1944,8 @@ def __init__(
         self.fixed_config = fixed_config
         super().__init__(tiling, **kwargs)
         self.cse = TritonCSE(self.newvar_prefix, self.suffix)
+        # Cache of values that can be reused for the prologue.
+        self.prologue_cache: dict[str, str] = {}
         self.prologue: IndentedBuffer = IndentedBuffer()
         self.post_loop_combine: IndentedBuffer = IndentedBuffer()
         self.post_loop_store: IndentedBuffer = IndentedBuffer()
@@ -2485,42 +2487,49 @@ def codegen_block_ptr(
             and self.range_trees[-1].is_loop
             and indexing.has_rindex()
         ) or indexing.can_lift:
-            block_descriptor_id = next(self.block_ptr_id)
-            if isinstance(indexing, BlockPtrOptions):
-                block_descriptor = f"block_ptr{block_descriptor_id}"
+            if indexing.can_lift and var in self.prologue_cache:
+                # Check for epilogue subtiling to reuse the same
+                # tensor descriptor.
+                block_descriptor = self.prologue_cache[var]
             else:
-                block_descriptor = f"tma_descriptor{block_descriptor_id}"
-            line_body = DeferredLine(
-                name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
-            )
-            if indexing.can_lift:
-                self.prologue.writeline(line_body)
-            else:
-                self.body.writeline(line_body)
-
-            if isinstance(indexing, BlockPtrOptions):
-                # Store for later use. If the buffer is removed the below advancements
-                # are no longer necessary
-                self.block_ptr_to_buffer[block_descriptor] = name
+                block_descriptor_id = next(self.block_ptr_id)
+                if isinstance(indexing, BlockPtrOptions):
+                    block_descriptor = f"block_ptr{block_descriptor_id}"
+                else:
+                    block_descriptor = f"tma_descriptor{block_descriptor_id}"
+                line_body = DeferredLine(
+                    name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
+                )
+                if indexing.can_lift:
+                    self.prologue.writeline(line_body)
+                    # Cache the descriptor for epilogue subtiling
+                    self.prologue_cache[var] = block_descriptor
+                else:
+                    self.body.writeline(line_body)

-            # Generate block pointer advancements, for later use.
-            for symt in TritonSymbols.reduction_types:
-                advance_offsets = indexing.advance_roffset(symt)
+                if isinstance(indexing, BlockPtrOptions):
+                    # Store for later use. If the buffer is removed the below advancements
+                    # are no longer necessary
+                    self.block_ptr_to_buffer[block_descriptor] = name
+
+                    # Generate block pointer advancements, for later use.
+                    for symt in TritonSymbols.reduction_types:
+                        advance_offsets = indexing.advance_roffset(symt)
+
+                        # Ignore identity advancements.
+                        if all(
+                            V.graph.sizevars.statically_known_equals(
+                                offset, sympy.Integer(0)
+                            )
+                            for offset in advance_offsets
+                        ):
+                            continue

-                # Ignore identity advancements.
-                if all(
-                    V.graph.sizevars.statically_known_equals(
-                        offset, sympy.Integer(0)
+                        advancements = self.pointer_advancements[symt]
+                        assert block_descriptor not in advancements, (
+                            f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
                         )
-                    for offset in advance_offsets
-                ):
-                    continue
-
-                advancements = self.pointer_advancements[symt]
-                assert block_descriptor not in advancements, (
-                    f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
-                )
-                advancements[block_descriptor] = advance_offsets
+                        advancements[block_descriptor] = advance_offsets
         else:
             block_descriptor = indexing.format(var)
         return block_descriptor, other
@@ -3879,6 +3888,7 @@ def codegen_prologue(self, code: IndentedBuffer):

         code.splice(self.prologue)
         self.prologue.clear()
+        self.prologue_cache.clear()

     def codegen_body(self):
         """
torch/_inductor/config.py

Lines changed: 2 additions & 0 deletions

@@ -1440,6 +1440,8 @@ class triton:
     # Should TMA store be enable from templates. TODO: Remove once we
     # can autotune over the result.
     enable_template_tma_store = os.environ.get("ENABLE_TEMPLATE_TMA_STORE", "0") == "1"
+    # Use epilogue subtiling. We allow disabling it due to limited B200 testing.
+    enable_epilogue_subtiling = os.environ.get("ENABLE_EPILOGUE_SUBTILING", "1") == "1"
     # Skip L1 cache for buffers that are used only once. Disabled by default
     skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1"
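The flag defaults to on; the default is read from the environment when torch._inductor.config is imported, and it can also be flipped per run with a config patch, as the tests above do:

```python
import torch._inductor.config as inductor_config

# Disable epilogue subtiling for a single compilation, e.g. while bisecting a
# Blackwell-specific issue; the default (ENABLE_EPILOGUE_SUBTILING=1) keeps it on.
with inductor_config.patch({"triton.enable_epilogue_subtiling": False}):
    ...  # torch.compile(...) calls issued here see the flag disabled
```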

torch/_inductor/kernel/mm.py

Lines changed: 22 additions & 1 deletion

@@ -645,14 +645,35 @@ def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_S
         )
         offs_cm = pid_m * BLOCK_M
         offs_cn = pid_n * BLOCK_N
-        # TODO: Add EPILOGUE_SUBTILE
+{%- if EPILOGUE_SUBTILE %}
+        tl.static_assert(BLOCK_N % 2 == 0)
+        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
+        acc = tl.permute(acc, (0, 2, 1))
+        acc0, acc1 = tl.split(acc)
+        {{store_output(
+            ("offs_cm", "offs_cn"),
+            "acc0",
+            indent_width=8,
+            val_shape=("BLOCK_M", "BLOCK_N // 2"),
+            block_indexing=True
+        )}}
+        offs_cn2 = offs_cn + BLOCK_N // 2
+        {{store_output(
+            ("offs_cm", "offs_cn2"),
+            "acc1",
+            indent_width=8,
+            val_shape=("BLOCK_M", "BLOCK_N // 2"),
+            block_indexing=True
+        )}}
+{%- else %}
         {{store_output(
             ("offs_cm", "offs_cn"),
             "accumulator",
             indent_width=8,
             val_shape=("BLOCK_M", "BLOCK_N"),
             block_indexing=True
         )}}
+{%- endif %}
 """

 blackwell_ws_persistent_device_tma_mm_template = TritonTemplate(
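The reshape → permute → split sequence in the template separates the accumulator into its left and right column halves, so the two store_output calls at offs_cn and offs_cn + BLOCK_N // 2 together cover the original tile. A quick CPU-side check of that layout algebra with plain torch ops (unbind standing in for tl.split):

```python
import torch

BLOCK_M, BLOCK_N = 4, 8
accumulator = torch.arange(BLOCK_M * BLOCK_N, dtype=torch.float32).reshape(BLOCK_M, BLOCK_N)

# Same shape manipulation as the template: (M, N) -> (M, 2, N//2) -> (M, N//2, 2) -> split.
acc = accumulator.reshape(BLOCK_M, 2, BLOCK_N // 2)
acc = acc.permute(0, 2, 1)
acc0, acc1 = acc.unbind(dim=-1)

# Writing acc0 at column offset 0 and acc1 at offset BLOCK_N // 2 rebuilds the tile.
out = torch.empty(BLOCK_M, BLOCK_N)
out[:, : BLOCK_N // 2] = acc0
out[:, BLOCK_N // 2 :] = acc1
torch.testing.assert_close(out, accumulator)
```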

torch/_inductor/select_algorithm.py

Lines changed: 24 additions & 14 deletions

@@ -107,6 +107,8 @@

 from torch._inductor.codegen.simd import IterationRangesRoot

+from .codegen.common import CSE
+

 class KernelNamespace:
     pass
@@ -261,13 +263,14 @@ class SubgraphInfo:
     loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     stores: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     ops_handler: Optional[V.WrapperHandler] = None  # type: ignore[name-defined]
+    cse: Optional["CSE[Any]"] = None

     # only copied over if not None
     range_trees: Optional[list["IterationRangesRoot"]] = None
     numels: Optional[dict[str, sympy.Expr]] = None

     def __post_init__(self):
-        self.only_copy_if_non_none_fields = ("range_trees", "numels")
+        self.only_copy_if_non_none_fields = ("range_trees", "numels", "cse")

     def to_dict(self):
         return {
@@ -557,12 +560,10 @@ def set_subgraph_body(self, body_name: str):
             setattr(self, key, value)

     @contextlib.contextmanager
-    def create_subgraph_body(self, body_name: str):
+    def create_subgraph_body(self, body_name: str, clear_cse: bool = False):
         assert body_name not in self.subgraph_bodies
         self.subgraph_bodies[body_name] = SubgraphInfo(
-            IndentedBuffer(),
-            None,
-            None,
+            IndentedBuffer(), None, None, cse=self.cse.clone() if clear_cse else None
         )
         with self.set_subgraph_body(body_name):
             yield
@@ -1071,7 +1072,13 @@ def _generate_index_from_tma_index(
             # XBLOCK/YBLOCK and xoffset/yoffset. We append XBLOCK/YBLOCK
             # to the top of the kernel so we can safely extract the tensor
             # descriptor construction to the top of the kernel.
-            self.defines += f"{block_name}: tl.constexpr = {block_size}\n"
+            if block_name in self.prologue_cache:
+                assert self.prologue_cache[block_name] == block_size, (
+                    f"Constant {block_name} must be used for all stores"
+                )
+            else:
+                self.prologue_cache[block_name] = block_size
+                self.prologue.writeline(f"{block_name}: tl.constexpr = {block_size}")
         else:
             block_name = block_size
         line0 = f"{offset_name} = {texpr(tma_index)}"
@@ -1124,7 +1131,10 @@ def store_output(
             block_indexing (bool): Are the input indices presented as offsets for creating the block (e.g.
                 inputs to TMA) or are they tensors that should be passed in directly.
         """
-        with self.create_subgraph_body("<STORE_OUTPUT>"):
+        subgraph_name = self._get_store_output_subgraph_name(
+            next(self.store_output_ctr)
+        )
+        with self.create_subgraph_body(subgraph_name, clear_cse=True):
             assert isinstance(indices, (list, tuple))
             assert isinstance(val, str)
             assert isinstance(mask, (str, type(None)))
@@ -1300,13 +1310,14 @@ def store_output(
             self.codegen_body()

             def hook():
-                # more stuff might have been added since the codegen_body above
-                self.codegen_body()
-                self.cse.invalidate(OrderedSet())
+                with self.set_subgraph_body(subgraph_name):
+                    # more stuff might have been added since the codegen_body above
+                    self.codegen_body()
+                    self.cse.invalidate(OrderedSet())

-                return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
+                    return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

-        return self._register_hook("<STORE_OUTPUT>", hook)
+        return self._register_hook(subgraph_name, hook)

     def _register_hook(
         self,
@@ -1812,8 +1823,7 @@ def make_extra() -> str:

         try:
             template = kernel.render(self.template, kwargs, caching_enabled)
-            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-                code = template.finalize_all()
+            code = template.finalize_all()
         except ZeroDivisionError:
             # TODO(nmacchioni): fix sympy division by zero
             return None