Commit 5fe61db

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Reviewed By: NikhilAPatel

Differential Revision: D83617233
1 parent e7c435c commit 5fe61db
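As a quick illustration of the new interface, here is a minimal sketch of exercising the `--scaling-pair` argument through the operator's own `parse_args` helper. Flag names are taken from the diff below; the shape values are arbitrary examples, and the import assumes a checkout that includes this commit:

```python
# Minimal sketch: exercise the new --scaling-pair argument via parse_args.
# Flag names come from the diff; the shape values are arbitrary examples.
from tritonbench.operators.fp8_gemm.fp8_gemm import parse_args

args = parse_args(
    ["--scaling-pair", "RowWise,RowWise", "--m", "1024", "--k", "1024", "--n", "1024"]
)
print(args.scaling_pair)  # "RowWise,RowWise" (argparse maps --scaling-pair to scaling_pair)
```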

File tree

2 files changed: +86 -48 lines changed

  tritonbench/operators/fp8_gemm/fp8_gemm.py
  tritonbench/operators/fp8_gemm/persistent.py

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 78 additions & 41 deletions
@@ -1,4 +1,5 @@
 import argparse
+
 import logging

 from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton

+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

@@ -46,7 +49,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +58,54 @@ def parse_args(args):
     return parser.parse_args(args)


+def get_scaling_recipe_int(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe_int: int,
+    transpose: bool = False,
+    custom_scale: float = None,
+):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> torch.Tensor:
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        return scale.to(torch.float32)
+
+    def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        return scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe_int:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
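The new module-level `get_scale` helper centralizes the scale math that previously lived inside `get_input_iter`. A self-contained sketch of the two recipes it dispatches on, written against plain PyTorch so it runs without the `torch._inductor` import (tensor shapes are arbitrary examples):

```python
# Self-contained sketch of the two scale recipes used by get_scale.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
x = torch.randn(4, 8)  # arbitrary example tensor

# TensorWise: a single float32 scalar scale for the whole tensor.
scale_tensorwise = (FP8_MAX / x.abs().max()).to(torch.float32)

# RowWise: one scale per row, shape [M, 1]; with transpose=True the reduction
# runs over dim=0 instead, producing a [1, N] scale for the second operand.
scale_rowwise = (FP8_MAX / x.abs().max(dim=1, keepdim=True).values).to(torch.float32)

print(scale_tensorwise.shape, scale_rowwise.shape)  # torch.Size([]) torch.Size([4, 1])
```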
@@ -66,53 +117,39 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
+        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a_int == ScalingType.TensorWise
+            and self.scaling_recipe_b_int == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16

     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            scale_a = get_scale(
+                a,
+                self.scaling_recipe_a_int,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            scale_b = get_scale(
+                b,
+                self.scaling_recipe_b_int,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
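For readers without a PyTorch checkout handy, here is a standalone sketch of the pair validation and recipe-to-int mapping that `Operator.__init__` now performs. The local `ScalingType` enum and `scaling_pairs` list stand in for the real definitions imported from `torch/_inductor/kernel/mm.py`; the member values and the set of valid pairs shown here are assumptions:

```python
# Standalone sketch of the scaling-pair validation in Operator.__init__.
# ScalingType and scaling_pairs are stand-ins for torch._inductor.kernel.mm;
# the enum values and the list of valid pairs are assumed for illustration.
from enum import IntEnum


class ScalingType(IntEnum):
    TensorWise = 0
    RowWise = 1


scaling_pairs = [
    (ScalingType.TensorWise, ScalingType.TensorWise),
    (ScalingType.RowWise, ScalingType.RowWise),
]


def parse_scaling_pair(pair: str) -> tuple[int, int]:
    recipe_a, recipe_b = pair.split(",")
    if (recipe_a, recipe_b) not in [(a.name, b.name) for a, b in scaling_pairs]:
        raise ValueError(f"Invalid scaling pair: {recipe_a}, {recipe_b}")
    return ScalingType[recipe_a].value, ScalingType[recipe_b].value


print(parse_scaling_pair("RowWise,RowWise"))  # (1, 1)
```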
@@ -192,7 +229,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
         )

     @register_benchmark(enabled=True)

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 8 additions & 7 deletions
@@ -1,10 +1,13 @@
 from functools import lru_cache
+
 from typing import Optional

 import torch
 import triton
 import triton.language as tl

+from torch._inductor.kernel.mm import ScalingType
+
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor

@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps


-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()

     # Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n

-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingType.RowWise:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
         b_block = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingType.RowWise:
             offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
             offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)

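Inside the kernel, `SCALING_MODE` now selects between tensor-wise and row-wise treatment of the scale pointers. The following is a plain-PyTorch sketch, not the Triton kernel itself, showing only how the two scale layouts broadcast over an `[M, N]` accumulator tile (shapes follow the comments in the diff; the scale values are arbitrary):

```python
# Illustrative only: how the two scale layouts broadcast over an [M, N] tile.
import torch

M, N = 4, 6
acc = torch.randn(M, N, dtype=torch.float32)  # stand-in for the fp32 accumulator tile

# Tensor-wise: scale_a and scale_b are scalars applied uniformly to the tile.
out_tensorwise = acc * torch.tensor(0.5) * torch.tensor(0.25)

# Row-wise: scale_a is [M, 1] and scale_b is [1, N]; broadcasting applies them
# per output row and per output column respectively.
scale_a = torch.rand(M, 1)
scale_b = torch.rand(1, N)
out_rowwise = acc * scale_a * scale_b

print(out_tensorwise.shape, out_rowwise.shape)  # both torch.Size([4, 6])
```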