Skip to content

Commit 9df1f8d

Browse files
jananisriramfacebook-github-bot
authored andcommitted
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments. Reviewed By: NikhilAPatel Differential Revision: D83617233
1 parent d8b41f2 commit 9df1f8d

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
23
import logging
34

45
from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
78
import torch._inductor.config as inductor_config
89
import triton
910

11+
from torch._inductor.kernel.mm import ScalingMode
12+
1013
from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
1114
from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
1215

@@ -46,7 +49,7 @@
4649
def parse_args(args):
4750
parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
4851
parser.add_argument("--llama", action="store_true")
49-
parser.add_argument("--scaling_rowwise", action="store_true")
52+
parser.add_argument("--scaling-mode", type=str, default="tensor")
5053
parser.add_argument("--m", type=int)
5154
parser.add_argument("--k", type=int)
5255
parser.add_argument("--n", type=int)
@@ -55,6 +58,15 @@ def parse_args(args):
5558
return parser.parse_args(args)
5659

5760

61+
def get_scaling_mode_int(scaling_mode: str) -> int:
62+
if scaling_mode == "tensor":
63+
return ScalingMode.TENSOR
64+
elif scaling_mode == "row":
65+
return ScalingMode.ROW
66+
else:
67+
raise ValueError(f"Invalid scaling mode: {scaling_mode}")
68+
69+
5870
class Operator(BenchmarkOperator):
5971
DEFAULT_METRICS = ["tflops", "gbps", "latency"]
6072
DEFAULT_PRECISION = "fp8"
@@ -65,11 +77,12 @@ def __init__(
6577
super().__init__(tb_args, extra_args)
6678
self.extra_args = parse_args(extra_args)
6779

80+
self.scaling_mode_int = get_scaling_mode_int(self.extra_args.scaling_mode).value
81+
6882
def _get_dtype(self):
69-
if self.extra_args.scaling_rowwise:
70-
return torch.bfloat16
71-
else:
83+
if self.scaling_mode_int == ScalingMode.TENSOR:
7284
return torch.float16
85+
return torch.bfloat16
7386

7487
def get_input_iter(self):
7588
def _get_scale_per_tensor(
@@ -102,10 +115,10 @@ def args(m, n, k):
102115
a = torch.randn(m, k, device=self.device).to(self._get_dtype())
103116
b = torch.randn(n, k, device=self.device).to(self._get_dtype())
104117

105-
if self.extra_args.scaling_rowwise:
118+
if self.scaling_mode_int == ScalingMode.ROW:
106119
scale_a = _get_scale_per_row(a)
107120
scale_b = _get_scale_per_row(b)
108-
else:
121+
else: # self.scaling_mode_int == ScalingMode.TENSOR
109122
scale_a = _get_scale_per_tensor(
110123
a, custom_scale=self.extra_args.per_tensor_scale_a
111124
)
@@ -191,7 +204,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
191204
scale_a,
192205
scale_b,
193206
self._get_dtype(),
194-
self.extra_args.scaling_rowwise,
207+
self.scaling_mode_int,
195208
)
196209

197210
@register_benchmark(enabled=True)

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from functools import lru_cache
2+
23
from typing import Optional
34

45
import torch
56
import triton
67
import triton.language as tl
78

9+
from torch._inductor.kernel.mm import ScalingMode
10+
811
from tritonbench.utils.env_utils import is_cuda
912
from tritonbench.utils.triton_utils import has_experimental_descriptor
1013

@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
410413
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
411414

412415

413-
def blackwell_persistent_tma(
414-
a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
415-
):
416+
def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
416417
configs = matmul_configs_blackwell()
417418

418419
# Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
471472
NUM_SMS=NUM_SMS, #
472473
num_stages=configs[shape_dtype]["num_stages"], #
473474
num_warps=configs[shape_dtype]["num_warps"], #
474-
SCALING_ROWWISE=scaling_rowwise,
475+
SCALING_MODE=scaling_mode, #
475476
WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"], #
476477
EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"], #
477478
)
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
504505
GROUP_SIZE_M: tl.constexpr, #
505506
ACC_TYPE: tl.constexpr,
506507
NUM_SMS: tl.constexpr,
507-
SCALING_ROWWISE: tl.constexpr, #
508+
SCALING_MODE: tl.constexpr, #
508509
WARP_SPECIALIZE: tl.constexpr,
509510
EPILOGUE_SUBTILE: tl.constexpr,
510511
): #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
538539
tile_id_c = start_pid - NUM_SMS
539540
num_pid_in_group = GROUP_SIZE_M * num_pid_n
540541

541-
if SCALING_ROWWISE:
542+
if SCALING_MODE == ScalingMode.ROW:
542543
# For row-wise scaling, we'll use the pointers as-is
543544
scale_a = scale_a_ptr
544545
scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
563564
b_block = b_desc.load([offs_bn, offs_k])
564565
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
565566

566-
if SCALING_ROWWISE:
567+
if SCALING_MODE == ScalingMode.ROW:
567568
offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
568569
offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
569570

0 commit comments

Comments
 (0)