From 7d670630e29b9b761738350f034d9915d0e9b1fa Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Thu, 21 Aug 2025 11:54:48 -0700
Subject: [PATCH] Move scaling logic to input generation (#338)

Summary:
Move scaling logic for FP8 benchmarks to `get_input_iter()`.

This diff aligns our fp8_gemm benchmarking suite with real-world practice: input tensors are of high-precision types (`bfloat16`, `float16`), scales are computed on the high-precision input tensors, and the input tensors are then cast to a lower precision (`float8_e4m3fn`).

This diff also avoids performing unsupported operations, such as `torch.max` and `torch.abs`, on low-precision data types. (A rough sketch of this scaling flow appears after the diff below.)

Pull Request resolved: https://github.com/meta-pytorch/tritonbench/pull/338

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Reviewed By: xuzhao9

Differential Revision: D80571223

Pulled By: jananisriram
---
 tritonbench/operators/fp8_gemm/fp8_gemm.py | 60 +++++++++++-----------
 1 file changed, 30 insertions(+), 30 deletions(-)

diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 1a617f3f8..dcfbbd6e5 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -56,14 +56,27 @@ def __init__(
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes")
             and self.external_shapes
@@ -90,46 +103,33 @@ def args(m, n, k):
             yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
@@ -137,15 +137,15 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
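
Note: a minimal sketch of the scaling flow this change aligns with, for reference only. It computes an amax-based scale on the high-precision (`float16`/`bfloat16`) tensor and casts to `float8_e4m3fn` only afterwards; the benchmark inputs above keep unit scales instead. The helper name `fp8_quantize`, the amax recipe, and the assumption of a CUDA device with FP8 support are illustrative and not part of this patch.

import torch

# Largest finite value representable in float8_e4m3fn (448.0).
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def fp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # abs/max run on the float16/bfloat16 tensor, where they are supported,
    # never on the float8 tensor.
    amax = x.abs().max().to(torch.float32)
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    x_fp8 = (x * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    # torch._scaled_mm takes the dequantization scale, i.e. the reciprocal.
    return x_fp8, scale.reciprocal()

# Illustrative shapes and device; requires an FP8-capable GPU.
a = torch.randn(512, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 128, device="cuda", dtype=torch.float16)
a_fp8, scale_a = fp8_quantize(a)
b_fp8, scale_b = fp8_quantize(b.T.contiguous())  # _scaled_mm expects b column-major
out = torch._scaled_mm(
    a_fp8, b_fp8.T, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)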