@@ -335,36 +335,34 @@ def get_type(provider):
         plot_name="matmul-performance",
         args={},
     ))
-def benchmark(M, N, K, provider, model=None):
+def benchmark(M, N, K, provider, model=None, args=None):
     in_dtype_a, in_dtype_b = [name_to_torch_types[x] for x in get_type(provider)]
     out_dtype = in_dtype_a

     quantiles = [0.5, 0.2, 0.8]
+    layout_tn = args.layout == 'tn'
+    a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
+    b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, 2, device='cuda')
     if 'hipblaslt' in provider:
-        a = torch.randn((M, K), dtype=in_dtype_a, device='cuda')
-        b = torch.randn((N, K), dtype=in_dtype_b, device='cuda')
-        b = b.T
-
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
         assert "triton" in provider
-        a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
-        b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
+
         scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
-    global verbose
-    if verbose:
-        print(f'SIZE: {M}, {N}, {K}   Best tuning config: ({matmul_kernel.best_config()})')
+    if args.v:
+        print(f'Best tuning config for M={M}, N={N}, K={K}, '
+              f'dtype={in_dtype_a}/{in_dtype_b}/{out_dtype}:\n({matmul_kernel.best_config})\n')
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)


 def parse_args():
     parser = argparse.ArgumentParser(
-        prog="GEMM tutorial example",
+        prog="AMD Triton GEMM kernel",
         allow_abbrev=False,
     )
@@ -375,48 +373,71 @@ def parse_args():
375373 "Model name to benchmark. Select from: [" + ", " .join (available_models ) +
376374 "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs." )
377375 parser .add_argument ('-model' , type = str , default = None , help = model_help )
378- parser .add_argument ('-b' , type = int , default = 0 , help = "Batch size used together with model." )
379- parser .add_argument ('-sq' , type = int , default = 0 , help = "Sequence length used together with model." )
380-
381376 parser .add_argument ("-v" , action = 'store_true' , default = False , help = "Print out the best tuning config" )
382377 parser .add_argument ("-M" , type = int , default = 0 )
383378 parser .add_argument ("-N" , type = int , default = 0 )
384379 parser .add_argument ("-K" , type = int , default = 0 )
380+ parser .add_argument ("-layout" , type = str , default = 'tn' )
381+ parser .add_argument ("-dtype" , type = str , default = None , help = "Data type of inputs and outputs" )
382+ parser .add_argument ("-b_dtype" , type = str , default = None ,
383+ help = "Data type of B operand, if specified (else same as dtype)" )
385384
386385 args = parser .parse_args ()
387386
388387 return args
389388
390389
390+ def get_line_vals_names (a_dtype = None , b_dtype = None ):
391+ line_vals = [
392+ 'hipblaslt(fp16/fp16)' , 'hipblaslt(bf16/bf16)' , 'triton(fp16/fp16)' , 'triton(bf16/bf16)' , 'triton(int8/int8)' ,
393+ 'triton(fp8e4/fp8e4)' , 'triton(fp8e5/fp8e5)' , 'triton(fp16/fp8e4)' , 'triton(fp16/fp8e5)'
394+ ]
395+ line_names = [
396+ "rocBLAS.Fp16" , "rocBLAS.Bf16" , "Triton.Fp16" , "Triton.Bf16" , "Triton.Int8" , "Triton.Fp8E4" , "Triton.Fp8E5" ,
397+ "Triton.Fp16.Fp8E4" , "Triton.Fp16.Fp8E5"
398+ ]
399+ assert not ((a_dtype is None ) ^ (b_dtype is None ))
400+ if a_dtype is not None :
401+ line_vals_suffix_str = '(' + a_dtype + '/' + b_dtype + ')'
402+ line_names_suffix_str = '.' + a_dtype + '.' + b_dtype
403+ line_vals = ['triton' + line_vals_suffix_str ]
404+ line_names = ['Triton' + line_names_suffix_str ]
405+ if not dtype_is_8_bit (name_to_torch_types [a_dtype ]) and \
406+ not dtype_is_8_bit (name_to_torch_types [b_dtype ]):
407+ line_vals += ['hipblaslt' + line_vals_suffix_str ]
408+ line_names += ['hipblaslt' + line_names_suffix_str ]
409+
410+ return line_vals , line_names
411+
412+
391413def main ():
392- # assign to a global verbose var to indicate whether print
393- # best tuning config
394- global verbose
395414 args = parse_args ()
396- verbose = args .v
397415
398416 if args .model :
399417 config_file = args .model_configs
400418 configs = get_model_configs (config_path = config_file , model_families = ["llama3" ], model = args .model )
401419 mnk_list = []
402- batch_size = args .b if args .b else 1
403420
404421 for model_name , config in configs .items ():
405- seq_len = args .sq if args .sq else 4096
406- M , N , K = batch_size * seq_len , config ["hidden_size" ], config ["intermediate_size" ]
422+ M , N , K = args .M or 8192 , config ["hidden_size" ], config ["intermediate_size" ]
407423 mnk_list .append ((model_name , M , N , K ))
408424
409425 benchmark .benchmarks .x_names = ['model' , 'M' , 'N' , 'K' ]
410426 benchmark .benchmarks .x_vals = mnk_list
411427
412- if args .M or args .N or args .K :
413- assert args .model is None , "Providing both -model and -M/N/K is not compatible! -model already fixes -M/N/K."
428+ a_dtype = args .dtype
429+ b_dtype = args .b_dtype or args .dtype
430+ assert a_dtype is None or a_dtype in name_to_torch_types , f"Unsupported dtype { a_dtype } "
431+ assert b_dtype is None or b_dtype in name_to_torch_types , f"Unsupported dtype { b_dtype } "
432+ benchmark .benchmarks .line_vals , benchmark .benchmarks .line_names = get_line_vals_names (a_dtype , b_dtype )
433+ if args .N or args .K :
434+ assert args .model is None , "Providing both -model and N/K is not compatible! -model already fixes N/K."
414435
415436 if args .M and args .N and args .K :
416437 x_vals = [(args .M , args .N , args .K )]
417438 benchmark .benchmarks .x_vals = x_vals
418439
419- benchmark .run (show_plots = True , print_data = True )
440+ benchmark .run (show_plots = True , print_data = True , args = args )
420441
421442
422443if __name__ == '__main__' :
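A minimal, self-contained sketch (not part of the patch) of the line-selection behavior the new -dtype/-b_dtype flags drive through get_line_vals_names: with no dtype pair, the full default set of hipblaslt/Triton lines is benchmarked; with an explicit pair, only the matching Triton line is kept, and a hipblaslt reference line is added back only when neither operand is an 8-bit type. The eight_bit set and the abridged default list below are assumptions standing in for the script's name_to_torch_types/dtype_is_8_bit checks.

def select_lines(a_dtype=None, b_dtype=None):
    # Assumption: these names mirror the script's 8-bit dtype keys.
    eight_bit = {'int8', 'fp8e4', 'fp8e5'}
    # Default: benchmark every provider/dtype combination (abridged here).
    line_vals = ['hipblaslt(fp16/fp16)', 'triton(fp16/fp16)', 'triton(fp8e4/fp8e4)']
    if a_dtype is None:
        return line_vals
    # Explicit pair: keep only the matching Triton line ...
    line_vals = [f'triton({a_dtype}/{b_dtype})']
    # ... and add a hipblaslt line only if neither operand is 8-bit.
    if a_dtype not in eight_bit and b_dtype not in eight_bit:
        line_vals.append(f'hipblaslt({a_dtype}/{b_dtype})')
    return line_vals

print(select_lines('fp16', 'fp8e5'))  # ['triton(fp16/fp8e5)'] -- hipblaslt skipped for 8-bit B
print(select_lines('bf16', 'bf16'))   # ['triton(bf16/bf16)', 'hipblaslt(bf16/bf16)']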