diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index a1a9ed1c7..fd27b48f6 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -189,6 +189,27 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             self.extra_args.scaling_rowwise,
         )
 
+    @register_benchmark(enabled=True)
+    def blackwell_pt2_fp8_gemm(self, a, b, scale_a, scale_b):
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            f = lambda a, b: torch._scaled_mm(
+                a,
+                b,
+                scale_a,
+                scale_b,
+                use_fast_accum=True,
+                out_dtype=self._get_dtype(),
+            )
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+
+        return lambda: compiled(a, b)
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
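For context, here is a minimal, self-contained sketch of the pattern the new `blackwell_pt2_fp8_gemm` variant relies on: compiling a `torch._scaled_mm` call with Inductor max-autotune restricted to the Triton GEMM backend. The shapes, scales, and output dtype below are illustrative assumptions, not values from the benchmark, and running it requires a GPU with fp8 support; the diff above additionally sets `autotune_fallback_to_aten=False` so autotuning cannot silently fall back to ATen (that flag is omitted here since it is not present in every PyTorch version).

```python
# Hedged sketch (assumed shapes/dtypes/scales): compile torch._scaled_mm
# with Inductor max-autotune limited to Triton-generated GEMM kernels.
import torch
import torch._inductor.config as inductor_config

M, K, N = 1024, 1024, 1024
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout,
# so materialize it as (N, K) row-major and transpose to (K, N).
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# Per-tensor scales are float32 scalar tensors.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

torch._dynamo.reset()  # drop any previously compiled graphs
with inductor_config.patch(
    max_autotune=True,
    max_autotune_gemm_backends="TRITON",
):
    f = lambda x, y: torch._scaled_mm(
        x, y, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.bfloat16
    )
    compiled = torch.compile(f, dynamic=False)
    out = compiled(a, b)  # first call triggers autotuning under the patch

print(out.shape, out.dtype)  # torch.Size([1024, 1024]) torch.bfloat16
```

Warming the compiled function inside the `inductor_config.patch` context, as the diff does, ensures the kernel selection happens with the patched settings; the returned lambda then only exercises the cached, already-compiled call, so the benchmark loop does not measure compilation time.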