
Commit 625108e

coconutruben authored and pytorchmergebot committed
[inductor] consolidate common GEMM triton param retrieval (pytorch#159383)
# Why

- Make loop iteration simpler
- Have a common spot where to make modifications that affect all the GEMM Triton templates, avoiding missed spots

# What

- Pull out the common logic of taking the BaseConfig objects and turning them into kwargs to feed into maybe_append_choice for the Triton GEMM templates

Differential Revision: [D79186962](https://our.internmc.facebook.com/intern/diff/D79186962)

Pull Request resolved: pytorch#159383
Approved by: https://github.com/jansel
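For orientation, the consolidated call-site pattern looks roughly like this sketch (names taken from the bmm.py diff below; the other Triton GEMM lowerings in this commit follow the same shape):

```python
# Before: each GEMM lowering fetched a device-specific config generator and
# expanded every config itself via mm_options()/mm_config_kwargs().
bmm_configs = V.choices.get_base_mm_configs(device_type)
for config in bmm_configs(
    m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
    bmm_template.maybe_append_choice(
        choices,
        input_nodes=(mat1, mat2),
        layout=layout,
        **mm_options(config, m, n, k, layout),
    )

# After: one shared entry point yields ready-made kwargs for every Triton GEMM template.
kernel_inputs = MMKernelInputs([mat1, mat2])
for kwargs in V.choices.get_mm_configs(
    kernel_inputs, layout, bmm_template.name, "bmm"
):
    bmm_template.maybe_append_choice(
        choices,
        input_nodes=kernel_inputs.nodes(),
        layout=layout,
        **kwargs,
    )
```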
1 parent 09e5a93 commit 625108e

10 files changed, +1345 −467 lines

test/inductor/test_max_autotune.py

Lines changed: 7 additions & 4 deletions
@@ -35,7 +35,10 @@
     TritonTemplate,
     TritonTemplateCaller,
 )
-from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
+from torch._inductor.template_heuristics import (
+    CUDAMMTemplateConfigHeuristic,
+    GemmConfig,
+)
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -1173,7 +1176,7 @@ def f(a, b):
         # Force only decomposeK choice
         with (
             mock.patch(
-                "torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
+                "torch._inductor.kernel.mm.V.choices.get_mm_configs"
             ) as base_mm_mock,
             mock.patch(
                 "torch._inductor.kernel.mm.use_decompose_k_choice"
@@ -1561,9 +1564,9 @@ def f(a, b):
         b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)
 
         with mock.patch(
-            "torch._inductor.kernel.mm.V.choices.get_config_heuristics"
+            "torch._inductor.template_registry.get_template_heuristic"
         ) as config_mock:
-            config_heuristics = CUDAConfigHeuristic()
+            config_heuristics = CUDAMMTemplateConfigHeuristic()
 
             # Traditionally, this would be set of all possible configs
             # We mock out the code path for the sake of the unit test

torch/_inductor/choices.py

Lines changed: 34 additions & 52 deletions
@@ -9,6 +9,7 @@
 
 from . import config
 from .codecache import write_text
+from .kernel_inputs import KernelInputs  # noqa: TC001
 from .metrics import get_metric_table, is_metric_table_enabled
 from .runtime.hints import DeviceProperties, ReductionHint
 from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
@@ -20,6 +21,7 @@
     ROCmConfigHeuristic,
     XPUConfigHeuristic,
 )
+from .template_registry import get_template_heuristic
 from .virtualized import V
 
 
@@ -71,58 +73,6 @@ def get_config_heuristics(
         else:
             return BaseConfigHeuristic()
 
-    # GEMM configs
-    def get_base_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
-            return mm_heuristics.get_mm_configs()
-        else:
-            return mm_heuristics.get_exhaustive_mm_configs()
-
-    def get_extra_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_extra_mm_configs()
-
-    def get_int8_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_int8_mm_configs()
-
-    def get_mixed_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_mixed_mm_configs()
-
-    def get_persistent_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_persistent_mm_configs()
-
-    def get_scaled_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_scaled_mm_configs()
-
-    def get_scaled_persistent_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_scaled_persistent_mm_configs()
-
-    def get_mm_plus_mm_configs(
-        self, device_type: Optional[str] = "cuda"
-    ) -> partial[Generator[TritonConfig, None, None]]:
-        mm_heuristics = self.get_config_heuristics(device_type)
-        return mm_heuristics.get_mm_plus_mm_configs()
-
     # Conv configs
     def get_conv_configs(
         self, device_type: Optional[str] = "cuda"
@@ -131,6 +81,7 @@ def get_conv_configs(
         return conv_heuristics.get_conv_configs()
 
     # Flex attention configs
+    # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
     def get_flex_attention_fwd_configs(
         self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
     ) -> list[Any]:
@@ -149,6 +100,37 @@ def get_flex_decode_configs(
         flex_heuristics = self.get_config_heuristics(device_type)
         return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
 
+    def get_mm_configs(
+        self,
+        kernel_inputs: KernelInputs,
+        layout: Any,
+        template_name: str,
+        op_name: str,
+    ) -> Generator[dict[str, Any], None, None]:
+        """
+        Get generator of template parameters for MM templates using template-specific heuristics.
+
+        Args:
+            kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
+            layout: Output layout
+            template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
+            op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
+
+        Yields:
+            Template parameter dictionaries ready for maybe_append_choice
+        """
+        input_tensors = kernel_inputs.nodes()
+        if len(input_tensors) < 2:
+            raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
+
+        # Extract device_type from kernel_inputs
+        device_type = kernel_inputs.device_type
+        assert device_type is not None, "get_mm_configs requires a valid device type"
+        # Get the appropriate template-specific heuristic
+        heuristic = get_template_heuristic(template_name, device_type, op_name)
+
+        yield from heuristic.get_template_configs(kernel_inputs, layout, op_name)
+
     def triton_kernel_kwargs(
         self,
         kernel_cls: type[TritonKernel],
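To make the retrieval path concrete: get_mm_configs looks up the heuristic registered for (template_name, device_type, op_name) and forwards whatever parameter dicts it yields. Below is a minimal sketch of the shape such a heuristic takes, with illustrative (not authoritative) field and kwarg names; the real classes, e.g. CUDAMMTemplateConfigHeuristic, live in template_heuristics.py and are registered through template_registry.

```python
# Hypothetical heuristic for illustration only; the concrete classes and the
# exact kwarg names they emit are defined in torch/_inductor/template_heuristics.py.
class ExampleMMTemplateHeuristic:
    def __init__(self, gemm_configs):
        # assumed: a list of GemmConfig-style tuning entries
        self.gemm_configs = gemm_configs

    def get_template_configs(self, kernel_inputs, layout, op_name):
        # Turn BaseConfig objects into kwargs that maybe_append_choice can
        # consume directly -- the common logic this PR pulls out of each lowering.
        for cfg in self.gemm_configs:
            yield {
                "BLOCK_M": cfg.block_m,  # illustrative field names
                "BLOCK_N": cfg.block_n,
                "BLOCK_K": cfg.block_k,
                "num_stages": cfg.num_stages,
                "num_warps": cfg.num_warps,
            }
```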

torch/_inductor/kernel/bmm.py

Lines changed: 33 additions & 37 deletions
@@ -6,6 +6,7 @@
 from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
 
 from .. import ir, lowering as L
+from ..kernel_inputs import MMKernelInputs
 from ..select_algorithm import (
     autotune_select_algorithm,
     ExternKernelChoice,
@@ -26,8 +27,6 @@
     addmm_epilogue,
     is_batch_stride_largest,
     mm_args,
-    mm_config_kwargs,
-    mm_options,
 )
 
 
@@ -40,13 +39,6 @@ def bmm_grid(b, m, n, meta, *, cdiv):
     return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
 
 
-def _is_large_block_for_cpu(m, n, k):
-    # Thresholds are experimentally determined to reduce Triton CPU compile times
-    if m > 128 or n > 128 or k > 128:
-        return True
-    return m * n > 2**12
-
-
 bmm_template = TritonTemplate(
     name="bmm",
     grid=bmm_grid,
@@ -175,9 +167,14 @@ def may_require_contiguous(t, meta_t):
         meta_mat2 = V.graph.current_node.args[1]
         mat2 = may_require_contiguous(mat2, meta_mat2)
 
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2 = mm_args(
         mat1, mat2, layout=layout, out_dtype=out_dtype
     )
+    name = "bmm"
+
+    # Create MMKernelInputs for BMM at the top
+    kernel_inputs = MMKernelInputs([mat1, mat2])
 
     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]  # Extract batch dimension
@@ -195,63 +192,65 @@ def may_require_contiguous(t, meta_t):
 
     if out_dtype:
         assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
-        aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
+        aten_func = aten_bmm_dtype.bind(
+            kernel_inputs.nodes(), layout, out_dtype=out_dtype
+        )
     else:
-        aten_func = aten_bmm.bind((mat1, mat2), layout)
+        aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout)
 
     # options to tune from
     choices = [aten_func] if use_aten_gemm_kernels() else []
 
-    device_type = ir.get_device_type(mat1)
-    bmm_configs = V.choices.get_base_mm_configs(device_type)
-
-    dtype = mat1.get_dtype()
     if use_triton_template(layout):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"
-        for config in bmm_configs(
-            m,
-            n,
-            k,
-            **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
+
+        for kwargs in V.choices.get_mm_configs(
+            kernel_inputs, layout, bmm_template.name, name
         ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=(mat1, mat2),
+                input_nodes=kernel_inputs.nodes(),
                 layout=layout,
-                **mm_options(config, m, n, k, layout),
+                **kwargs,
            )
     _, is_nonzero = _is_static_problem(layout)
     batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
     if (
         batch_stride_largest
         and is_nonzero
         and use_cutlass_template(layout, m, n, k)
-        and _use_cutlass_for_op("bmm")
+        and _use_cutlass_for_op(name)
     ):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
 
-        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])  # type: ignore[arg-type]
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, kernel_inputs.nodes()
+        )  # type: ignore[arg-type]
 
     if use_cpp_bmm_template(layout, mat1, mat2):
         from ..codegen.cpp_bmm_template import CppBmmTemplate
 
         CppBmmTemplate.add_choices(
             choices,
             layout,
-            [mat1, mat2],
+            kernel_inputs.nodes(),
         )
 
     if use_ck_gemm_template(layout, m, n, k):
-        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
 
-    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
 
 
 @L.register_lowering(aten.baddbmm)
 def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
 
+    # Create MMKernelInputs for BadDBMM at the top
+    kernel_inputs = MMKernelInputs([inp, mat1, mat2])
+
     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]
     counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
@@ -266,29 +265,26 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         inp.get_dtype(),
         layout,
     )
-
+    name = "baddbmm"
     # options to tune from
     choices = (
-        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
+        [aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)]
         if use_aten_gemm_kernels()
         else []
     )
 
-    device_type = ir.get_device_type(mat1)
-    bmm_configs = V.choices.get_base_mm_configs(device_type)
-
     if use_triton_template(layout):
-        for config in bmm_configs(
-            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
+        for kwargs in V.choices.get_mm_configs(
+            kernel_inputs, layout, bmm_template.name, name
        ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=(inp, mat1, mat2),
+                input_nodes=kernel_inputs.nodes(),
                 layout=layout,
-                **mm_options(config, m, n, k, layout),
+                **kwargs,
                 prefix_args=1,
                 epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                 epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
             )
 
-    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
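The bmm/baddbmm lowerings now build an MMKernelInputs once and thread it through the aten bindings, template choices, and autotuning. A rough sketch of the interface this code relies on, inferred only from the call sites above (the real class in kernel_inputs.py may carry additional state such as matrix-index bookkeeping):

```python
class MMKernelInputs:  # illustrative subset, not the actual implementation
    def __init__(self, input_nodes):
        self._nodes = list(input_nodes)

    def nodes(self):
        # the same node list handed to aten bindings, templates,
        # and autotune_select_algorithm
        return self._nodes

    @property
    def device_type(self):
        # e.g. "cuda" or "cpu"; used by V.choices.get_mm_configs
        # to select the registered heuristic
        return self._nodes[0].get_device().type
```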

torch/_inductor/kernel/conv.py

Lines changed: 0 additions & 9 deletions
@@ -29,7 +29,6 @@
2929
use_triton_template,
3030
)
3131
from ..virtualized import V
32-
from .mm_common import mm_config_kwargs
3332

3433

3534
if TYPE_CHECKING:
@@ -61,13 +60,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
6160
)
6261

6362

64-
def _is_large_block_for_cpu(m, n, k):
65-
# Thresholds are experimentally determined to reduce Triton CPU compile times
66-
if m > 256 or n > 256 or k > 256:
67-
return True
68-
return m * n * k > 2**17
69-
70-
7163
LOOP_BODY_2D = """
7264
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
7365
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@@ -603,7 +595,6 @@ def channels_last_conv():
603595
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
604596
out_chan,
605597
in_chan,
606-
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
607598
):
608599
if ndim == 2:
609600
conv2d_template.maybe_append_choice(
