Skip to content

Commit d6db1d2

Browse files
committed
Add options to print SW efficiency
1 parent 6f1525f commit d6db1d2

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ class MarkArgs:
352352
reports: str = ""
353353
n_runs: int = 1
354354
brief: bool = False
355+
hw_gbps: float = None
356+
hw_tflops: float = None
355357

356358
@staticmethod
357359
def load_cli_args() -> MarkArgs:
@@ -375,8 +377,32 @@ def load_cli_args() -> MarkArgs:
375377
action="store_true",
376378
help="Print only mean values without min, max, CV.",
377379
)
380+
parser.add_argument(
381+
"--hw_gbps",
382+
type=float,
383+
help="Hardware bandwidth in GB/s to calculate efficiency.",
384+
)
385+
parser.add_argument(
386+
"--hw_tflops",
387+
type=float,
388+
help="Hardware peak performance in TFLOPS to calculate efficiency.",
389+
)
378390
args = parser.parse_args()
379-
return MarkArgs(args.reports, args.n_runs, args.brief)
391+
return MarkArgs(args.reports, args.n_runs, args.brief, args.hw_gbps, args.hw_tflops)
392+
393+
394+
def enhance_df(df, mark_args: MarkArgs):
395+
df = df.copy()
396+
if mark_args.brief:
397+
df = df[[c for c in df.columns if not any(map(c.endswith, ("min", "max", "CV")))]]
398+
399+
for col in df.columns:
400+
if col.lower().replace("/", "p").endswith("gbps") and mark_args.hw_gbps:
401+
df[col + "-eff"] = (df[col] / mark_args.hw_gbps).apply(lambda x: f"{x:.1%}")
402+
elif col.lower().endswith("tflops") and mark_args.hw_tflops:
403+
df[col + "-eff"] = (df[col] / mark_args.hw_tflops).apply(lambda x: f"{x:.1%}")
404+
405+
return df
380406

381407

382408
class Mark:
@@ -462,12 +488,10 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
462488
col0, col1 = df.columns.tolist()
463489
df["Diff"] = df[col1] - df[col0]
464490

491+
df = enhance_df(df, mark_args)
465492
if print_data:
466493
print(bench.plot_name + ":")
467-
if mark_args.brief:
468-
print(df[[c for c in df.columns if not any(map(c.endswith, ("min", "max", "CV")))]].to_string())
469-
else:
470-
print(df.to_string())
494+
print(df.to_string())
471495

472496
if save_path:
473497
df.to_csv(os.path.join(save_path, f"{filename}.csv"), float_format=f"%.{save_precision}f", index=False)

0 commit comments

Comments
 (0)