Skip to content

Commit e6bcbbe

Browse files
[Inductor] No longer throw error in bmm out_dtype lowering due to template heuristics (pytorch#166922)
[Inductor] No longer throw error in bmm out_dtype lowering due to template heuristics (pytorch#166457)

Fixes pytorch#165892

Pull Request resolved: pytorch#166457
Approved by: https://github.com/coconutruben

(cherry picked from commit c2e3cc7)

Co-authored-by: PaulZhang12 <[email protected]>
1 parent 8f658d7 commit e6bcbbe

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

test/inductor/test_max_autotune.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,29 @@ def mm_transpose_relu(a, b):
14791479
# Check that contiguous transform was used
14801480
FileCheck().check("contiguous_mm").run(code[0])
14811481

1482+
@unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI")
1483+
@unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA")
1484+
def test_bmm_out_dtype(self):
1485+
def f(a, b):
1486+
return torch.bmm(a, b, out_dtype=torch.float32)
1487+
1488+
a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16)
1489+
b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16)
1490+
with config.patch(
1491+
max_autotune=True,
1492+
max_autotune_gemm_backends="TRITON",
1493+
):
1494+
compiled_f = torch.compile(f)
1495+
with self.assertRaisesRegex(
1496+
torch._inductor.exc.InductorError,
1497+
r"LoweringException: NoValidChoicesError: No choices to select",
1498+
):
1499+
out, code = run_and_get_code(compiled_f, a, b)
1500+
1501+
compiled_f = torch.compile(f)
1502+
out, code = run_and_get_code(compiled_f, a, b)
1503+
FileCheck().check("extern_kernels.bmm_dtype").run(code[0])
1504+
14821505
def test_triton_template_generated_code_cache_key(self):
14831506
generate_and_load_args = len(
14841507
inspect.signature(

torch/_inductor/kernel/bmm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,10 @@ def may_require_contiguous(t, meta_t):
208208
)
209209
)
210210

211-
if use_triton_template(layout, check_max_autotune=False):
211+
if use_triton_template(layout, check_max_autotune=False) and (
212+
out_dtype is None or out_dtype == mat1.get_dtype()
213+
):
212214
# TODO: add out_dtype support for Triton Template
213-
assert out_dtype is None, "out_dtype is not supported for Triton"
214215

215216
choices.extend(
216217
V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)

0 commit comments

Comments
 (0)