From 69edf003a19fe0937069ede8f8f72a0eb559b256 Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Mon, 25 Aug 2025 14:54:13 -0700
Subject: [PATCH] Add amax as default per-row scaling factor for fp8_gemm benchmark (#341)

Summary:
Add `amax` (absolute maximum) as the default scaling factor for per-row scaling for fp8 GEMMs, as is used in practice.

Reviewed By: xuzhao9

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Differential Revision: D80590746

Pulled By: jananisriram
---
 tritonbench/operators/fp8_gemm/fp8_gemm.py | 57 +++++++++++++++++-----
 1 file changed, 44 insertions(+), 13 deletions(-)

diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index fe1804ee0..e68cba2af 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)

@@ -54,18 +56,53 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -103,16 +140,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)

-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
        )

     @register_benchmark()
@@ -129,7 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             scale_a,
             scale_b,
             use_fast_accum=True,
-            out_dtype=self._get_out_dtype(),
+            out_dtype=self._get_dtype(),
         )
         compiled = torch.compile(f, dynamic=False)
         compiled(a, b)
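
For context, a minimal standalone sketch of the amax-based scaling scheme this patch adopts, outside the benchmark harness. It assumes a recent PyTorch with float8_e4m3fn support, an fp8-capable GPU, and row-wise scaling semantics for torch._scaled_mm; the tensor sizes and the per_row_scales helper name are illustrative and not part of the patch.

import torch

# Largest representable value in float8_e4m3fn (448.0); the amax-based scale
# maps each row's (or column's) absolute maximum onto this value.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def per_row_scales(a: torch.Tensor, b: torch.Tensor):
    # scale_a: [M, 1] from the row-wise amax of a; scale_b: [1, N] from the
    # column-wise amax of b, mirroring _get_scale_per_row in the patch.
    scale_a = (FP8_MAX / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)
    scale_b = (FP8_MAX / b.abs().max(dim=0, keepdim=True).values).to(torch.float32)
    return scale_a, scale_b


if torch.cuda.is_available():
    m, k, n = 64, 128, 32  # illustrative sizes
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    # torch._scaled_mm expects the second operand in column-major layout,
    # hence the .T.contiguous().T dance also used in the benchmark.
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16).T.contiguous().T
    scale_a, scale_b = per_row_scales(a, b)
    out = torch._scaled_mm(
        a.to(torch.float8_e4m3fn),
        b.to(torch.float8_e4m3fn),
        scale_a,
        scale_b,
        use_fast_accum=True,
        out_dtype=torch.bfloat16,  # row-wise scaling pairs with a bf16 output dtype
    )
    print(out.shape)  # torch.Size([64, 32])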