
Commit 04fb19a

Add an Inductor-generated _scaled_mm to TritonBench

Differential Revision: D79263834
Pull Request resolved: #316
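This adds a pt2_fp8_gemm variant that wraps torch._scaled_mm in torch.compile under an Inductor config patch (max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False), so autotuning is restricted to Inductor's Triton GEMM templates instead of falling back to ATen. The compiled function is called once during setup, so Dynamo tracing and kernel autotuning finish before the returned lambda is timed.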
1 parent 73c6e33 commit 04fb19a

File tree

1 file changed (+20, -1)


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 20 additions & 1 deletion
@@ -1,9 +1,10 @@
 import argparse
 import logging
 
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
+import torch._inductor.config as inductor_config
 import triton
 
 from tritonbench.utils.triton_op import (
@@ -90,6 +91,24 @@ def torch_fp8_gemm(self, a, b):
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
         )
 
+    @register_benchmark()
+    def pt2_fp8_gemm(self, a, b) -> Callable:
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            scale_a = torch.tensor(1.0, device=a.device)
+            scale_b = torch.tensor(1.0, device=a.device)
+            f = lambda a, b: torch._scaled_mm(
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
+            )
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+
+        return lambda: compiled(a, b)
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b):
         return lambda: tutorial_matmul(a, b)
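For reference, here is a minimal standalone sketch of the same compile pattern on synthetic inputs; it is not part of the commit. The 1024x1024 shapes, the CUDA device, and the unit scales are assumptions. torch._scaled_mm takes fp8 operands (float8_e4m3fn here, which requires fp8-capable hardware) and expects the second operand in column-major layout, hence the .t() below.

import torch
import torch._inductor.config as inductor_config

device = "cuda"
# Synthetic fp8 inputs (shapes are illustrative assumptions).
a = torch.randn(1024, 1024, device=device).to(torch.float8_e4m3fn)
b = torch.randn(1024, 1024, device=device).to(torch.float8_e4m3fn).t()  # column-major
scale_a = torch.tensor(1.0, device=device)  # float32 scalar scales
scale_b = torch.tensor(1.0, device=device)

torch._dynamo.reset()
with inductor_config.patch(
    max_autotune=True,
    max_autotune_gemm_backends="TRITON",
    autotune_fallback_to_aten=False,
):
    compiled = torch.compile(
        lambda x, y: torch._scaled_mm(
            x, y, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
        ),
        dynamic=False,
    )
    out = compiled(a, b)  # first call triggers tracing and autotuning

assert out.dtype == torch.float16

Once registered via @register_benchmark(), pt2_fp8_gemm runs alongside torch_fp8_gemm and triton_fp8_gemm in the fp8_gemm operator's benchmark suite.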
