
Commit 0e555dd

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Reviewed By: NikhilAPatel

Differential Revision: D83617233
1 parent ff11de3 commit 0e555dd
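
To make the new interface concrete, below is a minimal standalone sketch of how a scaling-pair string such as "TensorWise,RowWise" can be split, validated, and mapped to integer recipe codes, in the spirit of the fp8_gemm.py changes further down. The enum and the list of valid pairs here are local stand-ins for ScalingType and scaling_pairs from torch/_inductor/kernel/mm.py (which require a recent PyTorch build); they are illustrative assumptions, not code from this commit.

# Minimal, self-contained sketch; LocalScalingType and VALID_PAIRS are
# hypothetical stand-ins for torch._inductor.kernel.mm.ScalingType and
# scaling_pairs, used here only to illustrate the parsing/validation flow.
from enum import Enum


class LocalScalingType(Enum):
    TensorWise = 0
    RowWise = 1


VALID_PAIRS = [
    (LocalScalingType.TensorWise, LocalScalingType.TensorWise),
    (LocalScalingType.RowWise, LocalScalingType.RowWise),
]


def parse_scaling_pair(pair: str) -> tuple[int, int]:
    name_a, name_b = pair.split(",")
    recipe_a, recipe_b = LocalScalingType[name_a], LocalScalingType[name_b]
    if (recipe_a, recipe_b) not in VALID_PAIRS:
        raise ValueError(f"Invalid scaling pair: {name_a}, {name_b}")
    return recipe_a.value, recipe_b.value


print(parse_scaling_pair("TensorWise,TensorWise"))  # (0, 0)
print(parse_scaling_pair("RowWise,RowWise"))        # (1, 1)

Under this scheme, the removed --scaling_rowwise flag appears to correspond to passing "RowWise,RowWise", while the default "TensorWise,TensorWise" keeps the previous default behavior.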

File tree

2 files changed: 90 additions & 48 deletions


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 82 additions & 41 deletions
@@ -1,4 +1,5 @@
 import argparse
+
 import logging
 
 from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
 
@@ -46,7 +49,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +58,58 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_recipe_int(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe_int: int,
+    transpose: bool = False,
+    custom_scale: float = None,
+) -> (torch.Tensor, torch.Tensor):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> (torch.Tensor, torch.Tensor):
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        x *= scale
+        return x, scale.to(torch.float32)
+
+    def _get_scale_per_row(
+        x: torch.Tensor, transpose: bool = False
+    ) -> (torch.Tensor, torch.Tensor):
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        x = x.mul(scale)
+        return x, scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe_int:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -66,53 +121,39 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
+        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a_int == ScalingType.TensorWise
+            and self.scaling_recipe_b_int == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
            b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            a, scale_a = get_scale(
+                a,
+                self.scaling_recipe_a_int,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            b, scale_b = get_scale(
+                b,
+                self.scaling_recipe_b_int,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -192,7 +233,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
         )
 
     @register_benchmark(enabled=True)
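
A note on the refactor above: get_scale now returns both the scaled tensor and the float32 scale, where the old helpers returned only the scale. The snippet below is a standalone illustration of that scaling math and its round trip through float8_e4m3fn; it is not benchmark code, and it assumes a PyTorch build with torch.float8_e4m3fn support.

# Standalone illustration of the tensor-wise and row-wise scaling math used
# in get_scale above. Assumes torch.float8_e4m3fn is available (PyTorch 2.1+).
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

x = torch.randn(4, 8)

# Tensor-wise: one float32 scale for the whole tensor.
scale_t = FP8_MAX / x.abs().max()
x_t = (x * scale_t).to(torch.float8_e4m3fn)

# Row-wise: one float32 scale per row, shape [M, 1] (dim=0 would give [1, N]).
scale_r = FP8_MAX / x.abs().max(dim=1, keepdim=True).values
x_r = (x * scale_r).to(torch.float8_e4m3fn)

# Dividing by the same scale approximately recovers the original values;
# the residual is fp8 quantization error.
print((x - x_t.to(torch.float32) / scale_t).abs().max())
print((x - x_r.to(torch.float32) / scale_r).abs().max())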

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 8 additions & 7 deletions
@@ -1,10 +1,13 @@
 from functools import lru_cache
+
 from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from torch._inductor.kernel.mm import ScalingType
+
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor
 
@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingType.RowWise:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
            b_block = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingType.RowWise:
            offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
            offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
 
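
Because SCALING_MODE is a tl.constexpr, the `SCALING_MODE == ScalingType.RowWise` checks above are resolved at compile time, so each scaling mode gets its own kernel specialization with the untaken branch removed. The sketch below shows the same pattern on a toy kernel; it is illustrative only (none of these names come from persistent.py) and assumes a CUDA device with Triton installed.

# Toy example of compile-time specialization on an integer tl.constexpr,
# analogous to SCALING_MODE above. Hypothetical names; requires a CUDA GPU.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_rows_kernel(x_ptr, scale_ptr, out_ptr, N, MODE: tl.constexpr, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * N + cols, mask=mask)
    if MODE == 0:  # tensor-wise: a single scalar scale
        s = tl.load(scale_ptr)
    else:  # row-wise: one scale per row
        s = tl.load(scale_ptr + row)
    tl.store(out_ptr + row * N + cols, x * s, mask=mask)


def scale_rows(x: torch.Tensor, scale: torch.Tensor, mode: int) -> torch.Tensor:
    M, N = x.shape
    out = torch.empty_like(x)
    scale_rows_kernel[(M,)](x, scale, out, N, MODE=mode, BLOCK_N=triton.next_power_of_2(N))
    return out


x = torch.randn(8, 16, device="cuda")
row_scales = torch.arange(1.0, 9.0, device="cuda")
print(scale_rows(x, torch.tensor([2.0], device="cuda"), mode=0).allclose(x * 2))
print(scale_rows(x, row_scales, mode=1).allclose(x * row_scales[:, None]))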

0 commit comments
