diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 44953f06..98138751 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -1,4 +1,5 @@
 import argparse
+
 import logging
 
 from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
@@ -42,11 +45,15 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +62,86 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_recipe(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe: ScalingType,
+    transpose: bool = False,
+    custom_scale: float = None,
+) -> (torch.Tensor, torch.Tensor):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> (torch.Tensor, torch.Tensor):
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return x, torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
+        x *= scale
+        return x, scale.to(torch.float32)
+
+    def _get_scale_per_row(
+        x: torch.Tensor, transpose: bool = False
+    ) -> (torch.Tensor, torch.Tensor):
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            ).reciprocal()
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            ).reciprocal()
+        x = x.mul(scale)
+        return x, scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    def _get_scale_per_block(
+        x: torch.Tensor, block_outer: int, block_inner: int
+    ) -> (torch.Tensor, torch.Tensor):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = (
+            torch.finfo(torch.float8_e4m3fn).max / amax
+        ).reciprocal()  # keeps scale small enough such that scaling doesn't cause inf values
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+
+        if block_outer == 1 and block_inner == 128:
+            scale = (
+                scale.t().contiguous().t()
+            )  # 1x128 blocks need scales to be outer-dim-major
+
+        return x, scale.to(torch.float32)
+
+    match scaling_recipe:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -66,53 +153,52 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a = get_scaling_recipe(scaling_recipe_a)
+        self.scaling_recipe_b = get_scaling_recipe(scaling_recipe_b)
+
+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a in blockwise_scaling_types
+            or self.scaling_recipe_b in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a == ScalingType.TensorWise
+            and self.scaling_recipe_b == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            a, scale_a = get_scale(
+                a,
+                self.scaling_recipe_a,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            b, scale_b = get_scale(
+                b,
+                self.scaling_recipe_b,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -152,12 +238,16 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 
@@ -174,7 +264,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                 b.t(),
                 scale_a,
                 scale_b.t(),
-                use_fast_accum=True,
+                use_fast_accum=self.use_fast_accum,
                 out_dtype=self._get_dtype(),
             )
             compiled = torch.compile(f, dynamic=False)
@@ -186,13 +276,21 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
 
     @register_benchmark(enabled=True)
     def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
+        if self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.TensorWise:
+            scaling_recipe_int = 0
+        elif self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.RowWise:
+            scaling_recipe_int = 1
+        else:
+            raise ValueError(
+                f"Invalid scaling pair: {self.scaling_recipe_a}, {self.scaling_recipe_b} for blackwell_persistent_tma_fp8_gemm."
+            )
         return lambda: blackwell_persistent_tma(
             a,
             b,
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            scaling_recipe_int,
         )
 
     @register_benchmark(enabled=True)
diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py
index 5ce97117..9c44e5a6 100644
--- a/tritonbench/operators/fp8_gemm/persistent.py
+++ b/tritonbench/operators/fp8_gemm/persistent.py
@@ -1,10 +1,13 @@
 from functools import lru_cache
+
 from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from torch._inductor.kernel.mm import ScalingType
+
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor
 
@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingType.RowWise:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
             b_block = b_desc.load([offs_bn, offs_k])
             accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingType.RowWise:
             offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
             offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
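
# Illustrative sketch (not part of the patch above): a minimal standalone
# reproduction of the RowWise branch of get_scale(), showing the scale shapes
# that torch._scaled_mm consumes for row-wise scaling (scale_a as [M, 1], and
# scale_b as [1, N] once the callers apply .t()). The helper name
# rowwise_scale and the tensor sizes are illustrative assumptions; only the
# arithmetic mirrors _get_scale_per_row from the diff.
import torch


def rowwise_scale(x: torch.Tensor, transpose: bool = False):
    # transpose=False reduces over dim 1 -> scale shaped [M, 1];
    # transpose=True reduces over dim 0 -> scale shaped [1, N].
    dim = 0 if transpose else 1
    scale = (
        torch.finfo(torch.float8_e4m3fn).max
        / x.abs().max(dim=dim, keepdim=True).values
    ).reciprocal()
    return x.mul(scale), scale.to(torch.float32)


m, k = 4, 8
a = torch.randn(m, k, dtype=torch.bfloat16)
a_scaled, scale_a = rowwise_scale(a)
print(scale_a.shape)  # torch.Size([4, 1])
a_fp8 = a_scaled.to(torch.float8_e4m3fn)  # the benchmarked kernels expect float8_e4m3fn inputs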