Skip to content

Commit c33e82e

Browse files
cthifacebook-github-bot
authored andcommitted
Print all available kernels to bench if user doesn't pass valid ones (#4827)
Summary: Pull Request resolved: #4827 X-link: facebookresearch/FBGEMM#1853 We have so many kernels now, it is easier to print them since I always forget the naming Reviewed By: q10 Differential Revision: D81801350 fbshipit-source-id: 515a519e2728fffb5198f3a29827d070f1f165e1
1 parent 3ca2859 commit c33e82e

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import itertools
88
import os
9+
import sys
910

1011
from dataclasses import dataclass
1112
from datetime import datetime
@@ -19,6 +20,7 @@
1920
import pandas as pd
2021
import seaborn as sns
2122
import torch
23+
from tabulate import tabulate
2224

2325
try:
2426
from accelerators.utils.torch_profiler import profiler_or_nullcontext
@@ -401,6 +403,16 @@ def collect_kernels_to_profile(kernels: Optional[List[str]]) -> List[QuantizeOpB
401403
return [op for op in quantize_ops if op.name in kernels]
402404

403405

406+
def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
407+
data = sorted(
408+
[
409+
(op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No")
410+
for op in get_quantize_ops()
411+
]
412+
)
413+
print(tabulate(data, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
414+
415+
404416
@click.command()
405417
@click.option(
406418
"--output-dir",
@@ -542,12 +554,13 @@ def invoke_main(
542554
if enable_amd_env_vars:
543555
set_amd_env_vars()
544556
# If kernel filter is provided, parse it. Else, benchmark all kernels.
545-
quantize_ops = collect_kernels_to_profile(
546-
kernels.strip().split(",") if kernels else None
547-
)
557+
all_kernels = kernels.strip().split(",") if kernels else None
558+
quantize_ops = collect_kernels_to_profile(all_kernels)
548559

549560
if len(quantize_ops) == 0:
550-
raise Exception("No valid kernels to benchmark.")
561+
print("No valid kernels to benchmark. Available kernels:")
562+
print_kernels(all_kernels)
563+
sys.exit(1)
551564

552565
if num_iters < 1:
553566
print("Warning: Number of iterations must be at least 1.")

0 commit comments

Comments
 (0)