Commit 2e928c2

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Differential Revision: D83617233
1 parent 496c120 commit 2e928c2
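For context, a minimal sketch of how the refactored argument flow could be exercised; the module path and literal values below are assumptions for illustration, while parse_args() and get_scaling_mode_int() come from the diff that follows.

# Sketch only; the import path is inferred from the file location in this commit.
from tritonbench.operators.fp8_gemm.fp8_gemm import get_scaling_mode_int, parse_args

args = parse_args(["--scaling-mode", "row"])        # argparse exposes the flag as args.scaling_mode
mode_int = get_scaling_mode_int(args.scaling_mode)  # "tensor" -> 0, "row" -> 1
assert mode_int == 1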

File tree

tritonbench/operators/fp8_gemm/fp8_gemm.py
tritonbench/operators/fp8_gemm/persistent.py

2 files changed, 22 insertions(+), 14 deletions(-)

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 17 additions & 7 deletions
@@ -46,7 +46,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-mode", type=str, default="tensor")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +55,15 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_mode_int(scaling_mode: str) -> int:
+    if scaling_mode == "tensor":
+        return 0
+    elif scaling_mode == "row":
+        return 1
+    else:
+        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -65,11 +74,12 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        self.scaling_mode_int = get_scaling_mode_int(self.extra_args.scaling_mode)
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if self.scaling_mode_int == 0:
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
         def _get_scale_per_tensor(
@@ -102,10 +112,10 @@ def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
+            if self.scaling_mode_int == 1:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
-            else:
+            else:  # self.scaling_mode_int == 0
                 scale_a = _get_scale_per_tensor(
                     a, custom_scale=self.extra_args.per_tensor_scale_a
                 )
@@ -191,7 +201,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            self.scaling_mode_int,
         )
 
     @register_benchmark(enabled=True)
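Because the CLI surface is now a single --scaling-mode string, a future scaling mode only needs a new branch in get_scaling_mode_int rather than another argparse flag. A hypothetical extension sketch (the "block" mode and its integer value are illustrative, not part of this commit):

def get_scaling_mode_int(scaling_mode: str) -> int:
    if scaling_mode == "tensor":
        return 0
    elif scaling_mode == "row":
        return 1
    elif scaling_mode == "block":  # hypothetical future mode, shown only as the extension point
        return 2
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")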

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 5 additions & 7 deletions
@@ -410,9 +410,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +469,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +502,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +536,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == 1:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +561,7 @@ def blackwell_persistent_tma_kernel(
             b_block = b_desc.load([offs_bn, offs_k])
             accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == 1:
             offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
             offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
 
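Note that SCALING_MODE is a tl.constexpr, so `if SCALING_MODE == 1:` is resolved when Triton compiles the kernel; each scaling mode gets its own specialized kernel with the untaken branch removed rather than a runtime check. A standalone toy kernel (names and shapes here are illustrative and unrelated to the files above) showing the same constexpr-dispatch pattern:

import torch
import triton
import triton.language as tl


@triton.jit
def scaled_copy_kernel(x_ptr, out_ptr, scale, n, SCALING_MODE: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if SCALING_MODE == 1:  # evaluated at compile time; the other path is compiled out
        x = x * scale
    tl.store(out_ptr + offs, x, mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scaled_copy_kernel[grid](x, out, 2.0, x.numel(), SCALING_MODE=1, BLOCK=256)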
