Skip to content

Commit c2e3cc7

Browse files
PaulZhang12pytorchmergebot
authored andcommitted
[Inductor] No longer throw error in bmm out_dtype lowering due to template heuristics (pytorch#166457)
Fixes pytorch#165892 Pull Request resolved: pytorch#166457 Approved by: https://github.com/coconutruben
1 parent 5849eea commit c2e3cc7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torch/_inductor/kernel/bmm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,10 @@ def _to_dtype(x):
239239
templates_to_use.append(aten_handler)
240240
kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
241241

242-
if use_triton_template(layout, check_max_autotune=False):
242+
if use_triton_template(layout, check_max_autotune=False) and (
243+
out_dtype is None or out_dtype == mat1.get_dtype()
244+
):
243245
# TODO: add out_dtype support for Triton Template
244-
assert out_dtype is None, "out_dtype is not supported for Triton"
245246
templates_to_use.append(bmm_template)
246247

247248
# Single unified call for all templates

0 commit comments

Comments
 (0)