
Commit 41d6363

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Differential Revision: D83617233
1 parent 496c120 · commit 41d6363
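The payoff of the refactor is easiest to see in `get_scaling_mode_int` (added below in `fp8_gemm.py`): a future scaling mode needs only a new branch there plus matching kernel logic, not a new CLI flag. A hypothetical sketch, where the `"block"` mode and its value 2 are invented purely for illustration and are not part of this diff:

```python
# Hypothetical extension sketch; "block" and the value 2 are NOT in this diff.
def get_scaling_mode_int(scaling_mode: str) -> int:
    if scaling_mode == "tensor":
        return 0
    elif scaling_mode == "row":
        return 1
    elif scaling_mode == "block":  # hypothetical new mode slots in here
        return 2
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
```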

File tree: 2 files changed, +22 −12 lines


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 17 additions & 7 deletions
```diff
@@ -46,7 +46,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-mode", type=str, default="tensor")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
```
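With this change, callers select the mode by value rather than by flag, e.g. `--scaling-mode row` instead of the old `--scaling_rowwise` switch (assuming TritonBench's usual `run.py --op fp8_gemm` entry point; the exact command line depends on your setup). Unrecognized values fail fast in `get_scaling_mode_int`, added in the next hunk.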
```diff
@@ -55,6 +55,15 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_mode_int(scaling_mode: str) -> int:
+    if scaling_mode == "tensor":
+        return 0
+    elif scaling_mode == "row":
+        return 1
+    else:
+        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
```
```diff
@@ -65,11 +74,12 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        self.scaling_mode_int = get_scaling_mode_int(self.extra_args.scaling_mode)
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if self.scaling_mode_int == 0:
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
         def _get_scale_per_tensor(
```
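Note that the dtype mapping is behavior-preserving for the two existing modes: mode 0 (tensor-wise) still yields `torch.float16` and mode 1 (row-wise) still yields `torch.bfloat16`, while any future mode falls through to `torch.bfloat16` by default.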
```diff
@@ -102,10 +112,10 @@ def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
+            if self.scaling_mode_int == 1:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
-            else:
+            else:  # self.scaling_mode_int == 0
                 scale_a = _get_scale_per_tensor(
                     a, custom_scale=self.extra_args.per_tensor_scale_a
                 )
```
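The `_get_scale_per_tensor` / `_get_scale_per_row` helpers are defined earlier in `get_input_iter` and are unchanged by this diff. For readers without the full file, here is a minimal sketch of the usual amax-based fp8 scaling recipe such helpers follow; it assumes the e4m3 dynamic range, and the actual helpers may differ in details such as `custom_scale` handling:

```python
import torch

FP8_E4M3_MAX = 448.0  # assumption: dynamic range of torch.float8_e4m3fn

def scale_per_tensor(x: torch.Tensor) -> torch.Tensor:
    # Scaling mode 0: a single scalar scale for the whole tensor.
    return x.abs().max().float() / FP8_E4M3_MAX

def scale_per_row(x: torch.Tensor) -> torch.Tensor:
    # Scaling mode 1: one scale per row, from each row's absolute maximum.
    return x.abs().amax(dim=1).float() / FP8_E4M3_MAX
```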
```diff
@@ -191,7 +201,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            self.scaling_mode_int,
         )
 
     @register_benchmark(enabled=True)
```

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -411,7 +411,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 
 
 def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
+    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode
 ):
     configs = matmul_configs_blackwell()
 
```
```diff
@@ -471,7 +471,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
```
```diff
@@ -504,7 +504,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
```
```diff
@@ -538,7 +538,7 @@
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == 1:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
```
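Because `SCALING_MODE` is a `tl.constexpr`, comparisons like `SCALING_MODE == 1` are resolved when the kernel is compiled, so each mode gets its own specialized kernel with the untaken branch eliminated; the integer encoding costs nothing at runtime over the old boolean. A minimal, self-contained sketch of that pattern (a hypothetical kernel, not from this diff):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scaled_copy_kernel(
    x_ptr, scale_ptr, out_ptr, n,
    SCALING_MODE: tl.constexpr, BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if SCALING_MODE == 0:
        # Tensor-wise: one scalar scale, broadcast to the whole block.
        scale = tl.load(scale_ptr)
    else:
        # Element-wise scales as a simplified stand-in for row-wise indexing.
        scale = tl.load(scale_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)

# Usage: each SCALING_MODE value compiles its own specialization.
x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
scale = torch.full((1,), 2.0, device="cuda")
scaled_copy_kernel[(triton.cdiv(1024, 256),)](x, scale, out, 1024, SCALING_MODE=0, BLOCK=256)
```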
```diff
@@ -563,7 +563,7 @@
         b_block = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == 1:
         offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
         offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
```