Skip to content

Commit 207f104

Browse files
njriasanpytorchmergebot
authored andcommitted
[Triton] [Inductor] Set default configs for Blackwell Matmul Template (pytorch#163740)
Summary: Sets the default configs for the Blackwell Matmul Templates. Test Plan: NFC Differential Revision: D83116342 Pull Request resolved: pytorch#163740 Approved by: https://github.com/jananisriram
1 parent 3e1b1a3 commit 207f104

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torch/_inductor/template_heuristics/triton.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def __init__(self) -> None:
304304
GemmConfig(128, 128, 64, 5, 4),
305305
]
306306

307+
self.blackwell_persistent_mm_configs: list[BaseConfig] = [
308+
GemmConfig(128, 256, 64, 4, 8),
309+
GemmConfig(256, 128, 64, 3, 8),
310+
GemmConfig(128, 256, 128, 2, 8),
311+
GemmConfig(128, 256, 64, 3, 8),
312+
GemmConfig(128, 128, 128, 3, 4),
313+
GemmConfig(256, 128, 64, 3, 8),
314+
GemmConfig(128, 128, 128, 3, 8),
315+
]
316+
307317
self.scaled_mm_configs: list[BaseConfig] = [
308318
GemmConfig(128, 256, 32, 3, 8),
309319
GemmConfig(256, 128, 32, 3, 8),
@@ -2055,8 +2065,7 @@ class CUDABlackwellPersistentTMATemplateConfigHeuristic(
20552065

20562066
def __init__(self) -> None:
20572067
super().__init__()
2058-
# TODO: Tune mm_configs for blackwell.
2059-
self.mm_configs = self.persistent_mm_configs
2068+
self.mm_configs = self.blackwell_persistent_mm_configs
20602069

20612070

20622071
@register_template_heuristic(
@@ -2084,8 +2093,7 @@ class CUDABlackwellAddmmPersistentTMATemplateConfigHeuristic(
20842093

20852094
def __init__(self) -> None:
20862095
super().__init__()
2087-
# TODO: Tune mm_configs for blackwell.
2088-
self.mm_configs = self.persistent_mm_configs
2096+
self.mm_configs = self.blackwell_persistent_mm_configs
20892097

20902098

20912099
@register_template_heuristic(

0 commit comments

Comments
 (0)