176 changes: 133 additions & 43 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -1,4 +1,5 @@
import argparse

import logging

from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
import torch._inductor.config as inductor_config
import triton

from torch._inductor.kernel.mm import scaling_pairs, ScalingType

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

@@ -42,11 +45,15 @@
HAS_TMA = False
logger.warning(f"Failed to import TMA: {e}")

HAS_CUDA_129 = (
torch.cuda.is_available()
and torch.version.cuda is not None
and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (12, 9)
)


def parse_args(args):
parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
parser.add_argument("--llama", action="store_true")
parser.add_argument("--scaling_rowwise", action="store_true")
parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
parser.add_argument("--m", type=int)
parser.add_argument("--k", type=int)
parser.add_argument("--n", type=int)
@@ -55,6 +62,86 @@ def parse_args(args):
return parser.parse_args(args)

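For reference, a minimal standalone sketch of how the new `--scaling-pair` flag is parsed; the parser below mirrors only this flag (not the full TritonBench parser) and the argument values are made up.

import argparse

# Hypothetical, reduced parser: only the new flag is reproduced here.
parser = argparse.ArgumentParser(description="scaling-pair sketch")
parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")

# argparse maps "--scaling-pair" to the attribute "scaling_pair".
args = parser.parse_args(["--scaling-pair", "RowWise,RowWise"])
recipe_a, recipe_b = args.scaling_pair.split(",")
print(recipe_a, recipe_b)  # RowWise RowWise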

def get_scaling_recipe_int(scaling_recipe: str) -> ScalingType:
if scaling_recipe == "TensorWise":
return ScalingType.TensorWise
elif scaling_recipe == "RowWise":
return ScalingType.RowWise
elif scaling_recipe == "BlockWise1x128":
return ScalingType.BlockWise1x128
elif scaling_recipe == "BlockWise128x128":
return ScalingType.BlockWise128x128
else:
raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")


def get_scale(
x: torch.Tensor,
scaling_recipe_int: int,
transpose: bool = False,
custom_scale: float = None,
) -> (torch.Tensor, torch.Tensor):
def _get_scale_per_tensor(
x: torch.Tensor, custom_scale: float = None
) -> (torch.Tensor, torch.Tensor):
# For tensor-wise scaling, kernel requires a float32 scale tensor
if custom_scale:
return x, torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
x *= scale
return x, scale.to(torch.float32)

def _get_scale_per_row(
x: torch.Tensor, transpose: bool = False
) -> (torch.Tensor, torch.Tensor):
if transpose: # scale_b.shape should be [1, N]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=0, keepdim=True).values
).reciprocal()
else: # scale_a.shape should be [M, 1]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=1, keepdim=True).values
).reciprocal()
x = x.mul(scale)
return x, scale.to(
torch.float32
) # For row-wise scaling, kernel requires a float32 scale tensor

def _get_scale_per_block(
x: torch.Tensor, block_outer: int, block_inner: int
) -> (torch.Tensor, torch.Tensor):
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = (
torch.finfo(torch.float8_e4m3fn).max / amax
).reciprocal() # keeps scale small enough such that scaling doesn't cause inf values
x = (
x.mul(scale).flatten(2, 3).flatten(0, 1)
) # scale input up to dynamic range of float8_e4m3fn
scale = scale.flatten(2, 3).flatten(0, 1)

if block_outer == 1 and block_inner == 128:
scale = (
scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major

return x, scale.to(torch.float32)

match scaling_recipe_int:
case ScalingType.TensorWise:
return _get_scale_per_tensor(x, custom_scale=custom_scale)
case ScalingType.RowWise:
return _get_scale_per_row(x, transpose=transpose)
case ScalingType.BlockWise1x128:
return _get_scale_per_block(x, 1, 128)
case ScalingType.BlockWise128x128:
return _get_scale_per_block(x, 128, 128)
case _:
raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")

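As a standalone illustration of the row-wise recipe above (sizes are made up; runs on CPU): the per-row amax yields one scale per row, of shape [M, 1]; the benchmark later passes the transposed scale of the second operand to the kernel.

import torch

M, K = 4, 16
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x = torch.randn(M, K)
# One scale per row, shape [M, 1], mirroring _get_scale_per_row above.
scale = (fp8_max / x.abs().max(dim=1, keepdim=True).values).reciprocal()
x_scaled = x.mul(scale)
print(scale.shape)  # torch.Size([4, 1])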

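And a standalone sketch of the BlockWise1x128 scale layout (illustrative sizes): for an [M, K] input the per-block scale has shape [M, K // 128], and the `.t().contiguous().t()` step keeps that logical shape while making the memory layout outer-dim-major.

import torch

M, K = 4, 256
x = torch.randn(M, K)
amax = x.unflatten(1, (-1, 128)).abs().amax(dim=2)  # [M, K // 128]
scale = (torch.finfo(torch.float8_e4m3fn).max / amax).reciprocal()
scale = scale.t().contiguous().t()  # same shape, outer-dim-major strides
print(scale.shape, scale.stride())  # torch.Size([4, 2]) (1, 4)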
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]
DEFAULT_PRECISION = "fp8"
@@ -66,53 +153,52 @@ def __init__(
super().__init__(tb_args, extra_args)
self.extra_args = parse_args(extra_args)

scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
if (scaling_recipe_a, scaling_recipe_b) not in [
(a.name, b.name) for a, b in scaling_pairs
]:
raise ValueError(
f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
)
self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value

blockwise_scaling_types = [
ScalingType.BlockWise1x128,
ScalingType.BlockWise128x128,
]
self.contains_blockwise_scaling = (
self.scaling_recipe_a_int in blockwise_scaling_types
or self.scaling_recipe_b_int in blockwise_scaling_types
)

self.use_fast_accum = (
not self.contains_blockwise_scaling
) # BlockWise scaled_gemm does not support use_fast_accum=True

def _get_dtype(self):
if self.extra_args.scaling_rowwise:
return torch.bfloat16
else:
if (
self.scaling_recipe_a_int == ScalingType.TensorWise
and self.scaling_recipe_b_int == ScalingType.TensorWise
):
return torch.float16
return torch.bfloat16

def get_input_iter(self):
def _get_scale_per_tensor(
x: torch.Tensor, custom_scale: float = None
) -> torch.Tensor:
# For tensor-wise scaling, kernel requires a float32 scale tensor
if custom_scale:
return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
return scale.to(torch.float32)

def _get_scale_per_row(
x: torch.Tensor, transpose: bool = False
) -> torch.Tensor:
if transpose: # scale_b.shape should be [1, N]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=0, keepdim=True).values
)
else: # scale_a.shape should be [M, 1]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=1, keepdim=True).values
)
return scale.to(
torch.float32
) # For row-wise scaling, kernel requires a float32 scale tensor

def args(m, n, k):
a = torch.randn(m, k, device=self.device).to(self._get_dtype())
b = torch.randn(n, k, device=self.device).to(self._get_dtype())

if self.extra_args.scaling_rowwise:
scale_a = _get_scale_per_row(a)
scale_b = _get_scale_per_row(b)
else:
scale_a = _get_scale_per_tensor(
a, custom_scale=self.extra_args.per_tensor_scale_a
)
scale_b = _get_scale_per_tensor(
b, custom_scale=self.extra_args.per_tensor_scale_b
)
a, scale_a = get_scale(
a,
self.scaling_recipe_a_int,
custom_scale=self.extra_args.per_tensor_scale_a,
)
b, scale_b = get_scale(
b,
self.scaling_recipe_b_int,
custom_scale=self.extra_args.per_tensor_scale_b,
)

# Kernels expect dtype=float8_e4m3fn
a = a.to(torch.float8_e4m3fn)
@@ -152,12 +238,16 @@ def get_x_val(self, example_inputs) -> float:

@register_benchmark(baseline=True)
def torch_fp8_gemm(self, a, b, scale_a, scale_b):
assert (
not self.contains_blockwise_scaling or HAS_CUDA_129
), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"

return lambda: torch._scaled_mm(
a,
b.t(),
scale_a,
scale_b.t(),
use_fast_accum=True,
use_fast_accum=self.use_fast_accum,
out_dtype=self._get_dtype(),
)

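For orientation, a minimal direct call to torch._scaled_mm with tensor-wise (scalar) scales, outside the benchmark harness; it assumes a CUDA device with float8 support (e.g., H100) and uses made-up sizes and unit scales.

import torch

M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# Tensor-wise scaling: one float32 scale per operand (unit scales for simplicity).
scale_a = torch.tensor(1.0, dtype=torch.float32, device="cuda")
scale_b = torch.tensor(1.0, dtype=torch.float32, device="cuda")
# b.t() gives the column-major layout torch._scaled_mm expects for the second operand.
out = torch._scaled_mm(a, b.t(), scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([128, 64])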
@@ -174,7 +264,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
b.t(),
scale_a,
scale_b.t(),
use_fast_accum=True,
use_fast_accum=self.use_fast_accum,
out_dtype=self._get_dtype(),
)
compiled = torch.compile(f, dynamic=False)
@@ -192,7 +282,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
scale_a,
scale_b,
self._get_dtype(),
self.extra_args.scaling_rowwise,
0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
)

@register_benchmark(enabled=True)
15 changes: 8 additions & 7 deletions tritonbench/operators/fp8_gemm/persistent.py
@@ -1,10 +1,13 @@
from functools import lru_cache

from typing import Optional

import torch
import triton
import triton.language as tl

from torch._inductor.kernel.mm import ScalingType

from tritonbench.utils.env_utils import is_cuda
from tritonbench.utils.triton_utils import has_experimental_descriptor

@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps


def blackwell_persistent_tma(
a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
):
def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
configs = matmul_configs_blackwell()

# Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
NUM_SMS=NUM_SMS, #
num_stages=configs[shape_dtype]["num_stages"], #
num_warps=configs[shape_dtype]["num_warps"], #
SCALING_ROWWISE=scaling_rowwise,
SCALING_MODE=scaling_mode, #
WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"], #
EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"], #
)
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
GROUP_SIZE_M: tl.constexpr, #
ACC_TYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
SCALING_ROWWISE: tl.constexpr, #
SCALING_MODE: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr,
EPILOGUE_SUBTILE: tl.constexpr,
): #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = GROUP_SIZE_M * num_pid_n

if SCALING_ROWWISE:
if SCALING_MODE == ScalingType.RowWise:
# For row-wise scaling, we'll use the pointers as-is
scale_a = scale_a_ptr
scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
b_block = b_desc.load([offs_bn, offs_k])
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

if SCALING_ROWWISE:
if SCALING_MODE == ScalingType.RowWise:
offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
