From 56cba4de107f049cb4daa909174608d92c3c2888 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Mon, 18 Aug 2025 09:54:01 -0700 Subject: [PATCH] Add benchmarking on shapes from CSV files to fp8_gemm (#332) Summary: Add the ability to benchmark fp8_gemm kernels on shapes read from a CSV file. Add a `--filepath` CLI argument that takes the path to the CSV file; each row holds three integers that are unpacked as m, n, k, and blank lines and lines whose first non-whitespace character is `#` are skipped. `--m` takes precedence over `--filepath` when both are given. Also create the `torch.tensor(1.0)` scales in the `else` branch of `pt2_fp8_gemm` with an explicit float32 dtype, place `scale_b` on `b.device`, and add a trailing comma after `NUM_SMS=num_sms` in stream_k.py. Differential Revision: D80381352 --- tritonbench/operators/fp8_gemm/fp8_gemm.py | 28 ++++++++++++++++++++-- tritonbench/operators/gemm/stream_k.py | 2 +- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 4f0d8043..e4d15385 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -1,4 +1,5 @@ import argparse +import csv import logging from typing import Any, Callable, List, Optional @@ -41,9 +42,28 @@ def parse_args(args): parser.add_argument("--m", type=int) parser.add_argument("--k", type=int) parser.add_argument("--n", type=int) + parser.add_argument("--filepath", type=str, default=None) return parser.parse_args(args) +def read_fp8_shapes(filepath): + fp8_shapes = [] + try: + with open(filepath, "r", newline="") as csvfile: + filtered_lines = ( + line + for line in csvfile + if line.strip() and not line.lstrip().startswith("#") + ) + reader = csv.reader(filtered_lines) + for row in reader: + fp8_shapes.append(tuple(map(int, row))) + except Exception as e: + logger.error(f"Failed to read fp8 shapes from {filepath}: {e}") + raise e + return fp8_shapes + + class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "gbps", "latency"] DEFAULT_PRECISION = "fp8" @@ -70,6 +90,10 @@ def args(m, n, k): yield args(m, n, k) elif self.extra_args.m: yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k) + elif self.extra_args.filepath: + fp8_shapes = read_fp8_shapes(self.extra_args.filepath) + for m, n, k in fp8_shapes: + yield args(m, n, k) else: for i in range(10, 15): for
j in range(0, 4): @@ -114,8 +138,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable: scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device) out_dtype = torch.bfloat16 else: - scale_a = torch.tensor(1.0, device=a.device) - scale_b = torch.tensor(1.0, device=a.device) + scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=b.device) out_dtype = torch.float16 f = lambda a, b: torch._scaled_mm( a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype diff --git a/tritonbench/operators/gemm/stream_k.py b/tritonbench/operators/gemm/stream_k.py index 6cbc857c..169b2b17 100644 --- a/tritonbench/operators/gemm/stream_k.py +++ b/tritonbench/operators/gemm/stream_k.py @@ -646,6 +646,6 @@ def grid(META): K, # FP8_OUTPUT=dtype == torch.float8_e4m3fn, # ENABLE_BUFFER_OPS_ASSUMES=True, # - NUM_SMS=num_sms # + NUM_SMS=num_sms, # ) return c