diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 1a617f3f8..dcfbbd6e5 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -56,14 +56,27 @@ def __init__(
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes")
             and self.external_shapes
@@ -90,46 +103,33 @@ def args(m, n, k):
             yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
@@ -137,15 +137,15 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
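
For reference, below is a minimal, self-contained sketch of the input contract this patch establishes: get_input_iter now yields (a, b, scale_a, scale_b), the scales are built before the fp8 cast, and every registered benchmark takes the same four arguments, with the output dtype chosen by _get_out_dtype(). The snippet is illustrative only; it assumes a CUDA device with float8_e4m3fn support, and the shapes and the scaling_rowwise flag are placeholders rather than values read from the benchmark config.

    import torch

    m, n, k = 1024, 1024, 1024
    device = "cuda"  # assumes an fp8-capable GPU
    scaling_rowwise = False  # placeholder for self.extra_args.scaling_rowwise

    # Build inputs the way args(m, n, k) does: fp16 first, fp8 cast last,
    # with b laid out column-major as torch._scaled_mm expects.
    a = torch.randn(m, k, device=device).to(torch.float16)
    b = torch.randn(k, n, device=device).to(torch.float16).T.contiguous().T

    if scaling_rowwise:
        # Row-wise scaling: per-row scale for a, per-column scale for b, bf16 output.
        scale_a = torch.ones((m, 1), dtype=torch.float32, device=device)
        scale_b = torch.ones((1, n), dtype=torch.float32, device=device)
        out_dtype = torch.bfloat16
    else:
        # Tensor-wise scaling: scalar scales, fp16 output.
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_dtype = torch.float16

    # Kernels expect dtype=float8_e4m3fn
    a = a.to(torch.float8_e4m3fn)
    b = b.to(torch.float8_e4m3fn)

    # The baseline torch_fp8_gemm benchmark times exactly this call.
    c = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype)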