
Commit b4924f1

Vinayak/gemm benchmarking (#746)
Add the ability to specify the data type (dtype) as a command-line argument.
1 parent: 46ca63c
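With this change, the input and output data types can be selected at the command line instead of being fixed in the script. An illustrative invocation (flag names and dtype spellings are taken from the diff below):

    python python/perf-kernels/gemm.py -dtype fp16 -b_dtype fp8e4 -M 4096 -N 4096 -K 4096 -v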

File tree: 1 file changed (+45, -24 lines)


python/perf-kernels/gemm.py

Lines changed: 45 additions & 24 deletions
@@ -335,36 +335,34 @@ def get_type(provider):
         plot_name="matmul-performance",
         args={},
     ))
-def benchmark(M, N, K, provider, model=None):
+def benchmark(M, N, K, provider, model=None, args=None):
     in_dtype_a, in_dtype_b = [name_to_torch_types[x] for x in get_type(provider)]
     out_dtype = in_dtype_a

     quantiles = [0.5, 0.2, 0.8]
+    layout_tn = args.layout == 'tn'
+    a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
+    b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, 2, device='cuda')
     if 'hipblaslt' in provider:
-        a = torch.randn((M, K), dtype=in_dtype_a, device='cuda')
-        b = torch.randn((N, K), dtype=in_dtype_b, device='cuda')
-        b = b.T
-
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
         assert "triton" in provider
-        a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
-        b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
+
         scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
-    global verbose
-    if verbose:
-        print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.best_config()})')
+    if args.v:
+        print(f'Best tuning config for M={M}, N={N}, K={K}, '
+              f'dtype={in_dtype_a} / {in_dtype_b} / {out_dtype}: \n({matmul_kernel.best_config})\n')
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)


 def parse_args():
     parser = argparse.ArgumentParser(
-        prog="GEMM tutorial example",
+        prog="AMD Triton GEMM kernel",
        allow_abbrev=False,
     )

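For reference, the benchmark body above leans on helpers defined elsewhere in gemm.py that this diff does not show. A minimal sketch of the two dtype helpers, assuming the entries implied by the dtype names in this commit (the fp8 mappings in particular are a guess; ROCm builds of PyTorch typically expose the "fnuz" fp8 variants):

import torch

# Assumed mapping from dtype strings accepted on the command line to torch
# dtypes; the real table lives elsewhere in gemm.py.
name_to_torch_types = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'int8': torch.int8,
    'fp8e4': torch.float8_e4m3fnuz,  # assumption: AMD fnuz fp8 formats
    'fp8e5': torch.float8_e5m2fnuz,
}

def dtype_is_8_bit(dtype) -> bool:
    # Assumed predicate: 8-bit operand types take the a_scale/b_scale path.
    return dtype in (torch.int8, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)

The perf lambda at the end of benchmark() reports TFLOP/s: a GEMM performs 2 * M * N * K floating-point operations, the 1e-12 factor converts that count to tera-ops, and ms * 1e-3 converts the measured milliseconds to seconds.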
@@ -375,48 +373,71 @@ def parse_args():
         "Model name to benchmark. Select from: [" + ", ".join(available_models) +
         "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
-    parser.add_argument('-b', type=int, default=0, help="Batch size used together with model.")
-    parser.add_argument('-sq', type=int, default=0, help="Sequence length used together with model.")
-
     parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
     parser.add_argument("-M", type=int, default=0)
     parser.add_argument("-N", type=int, default=0)
     parser.add_argument("-K", type=int, default=0)
+    parser.add_argument("-layout", type=str, default='tn')
+    parser.add_argument("-dtype", type=str, default=None, help="Data type of inputs and outputs")
+    parser.add_argument("-b_dtype", type=str, default=None,
+                        help="Data type of B operand, if specified (else same as dtype)")

     args = parser.parse_args()

     return args


+def get_line_vals_names(a_dtype=None, b_dtype=None):
+    line_vals = [
+        'hipblaslt(fp16/fp16)', 'hipblaslt(bf16/bf16)', 'triton(fp16/fp16)', 'triton(bf16/bf16)', 'triton(int8/int8)',
+        'triton(fp8e4/fp8e4)', 'triton(fp8e5/fp8e5)', 'triton(fp16/fp8e4)', 'triton(fp16/fp8e5)'
+    ]
+    line_names = [
+        "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5",
+        "Triton.Fp16.Fp8E4", "Triton.Fp16.Fp8E5"
+    ]
+    assert not ((a_dtype is None) ^ (b_dtype is None))
+    if a_dtype is not None:
+        line_vals_suffix_str = '(' + a_dtype + '/' + b_dtype + ')'
+        line_names_suffix_str = '.' + a_dtype + '.' + b_dtype
+        line_vals = ['triton' + line_vals_suffix_str]
+        line_names = ['Triton' + line_names_suffix_str]
+        if not dtype_is_8_bit(name_to_torch_types[a_dtype]) and \
+           not dtype_is_8_bit(name_to_torch_types[b_dtype]):
+            line_vals += ['hipblaslt' + line_vals_suffix_str]
+            line_names += ['hipblaslt' + line_names_suffix_str]
+
+    return line_vals, line_names
+
+
 def main():
-    # assign to a global verbose var to indicate whether print
-    # best tuning config
-    global verbose
     args = parse_args()
-    verbose = args.v

     if args.model:
         config_file = args.model_configs
         configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)
         mnk_list = []
-        batch_size = args.b if args.b else 1

         for model_name, config in configs.items():
-            seq_len = args.sq if args.sq else 4096
-            M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
+            M, N, K = args.M or 8192, config["hidden_size"], config["intermediate_size"]
             mnk_list.append((model_name, M, N, K))

         benchmark.benchmarks.x_names = ['model', 'M', 'N', 'K']
         benchmark.benchmarks.x_vals = mnk_list

-    if args.M or args.N or args.K:
-        assert args.model is None, "Providing both -model and -M/N/K is not compatible! -model already fixes -M/N/K."
+    a_dtype = args.dtype
+    b_dtype = args.b_dtype or args.dtype
+    assert a_dtype is None or a_dtype in name_to_torch_types, f"Unsupported dtype {a_dtype}"
+    assert b_dtype is None or b_dtype in name_to_torch_types, f"Unsupported dtype {b_dtype}"
+    benchmark.benchmarks.line_vals, benchmark.benchmarks.line_names = get_line_vals_names(a_dtype, b_dtype)
+    if args.N or args.K:
+        assert args.model is None, "Providing both -model and N/K is not compatible! -model already fixes N/K."

     if args.M and args.N and args.K:
         x_vals = [(args.M, args.N, args.K)]
         benchmark.benchmarks.x_vals = x_vals

-    benchmark.run(show_plots=True, print_data=True)
+    benchmark.run(show_plots=True, print_data=True, args=args)


 if __name__ == '__main__':
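To make the provider selection above concrete, here is a sketch of what get_line_vals_names() returns in each mode; the expected values are read off the function body in the diff, assuming dtype_is_8_bit() is true only for int8 and the fp8 types:

# No dtype given: the full default sweep across hipBLASLt and Triton.
vals, names = get_line_vals_names()
assert len(vals) == 9 and vals[0] == 'hipblaslt(fp16/fp16)'

# -dtype fp16: one Triton line plus a hipBLASLt reference, since fp16
# is not an 8-bit type.
vals, names = get_line_vals_names('fp16', 'fp16')
assert vals == ['triton(fp16/fp16)', 'hipblaslt(fp16/fp16)']

# -dtype fp8e4: Triton only; the hipBLASLt reference is skipped for
# 8-bit operands.
vals, names = get_line_vals_names('fp8e4', 'fp8e4')
assert vals == ['triton(fp8e4/fp8e4)']

Note that main() derives b_dtype as args.b_dtype or args.dtype, so passing -b_dtype without -dtype trips the XOR assertion in get_line_vals_names(); the helper expects either both dtypes or neither.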
