@@ -9,8 +9,7 @@
 import torch
 import triton
 import triton.language as tl
-
-from streamk_kernel import streamk_gemm
+import importlib
 
 from datetime import datetime
 import multiprocessing
@@ -206,9 +205,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
     os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
     jobId = gpuIdx
     while jobId < jobs:
-        kernel_name = get_filename_profile_driver(M, N, K, jobId)
+        kernelname = get_filename_profile_driver(M, N, K, jobId)
         if verbose:
-            print(f"profiling {kernel_name} on GPU {gpuid}")
+            print(f"profiling {kernelname} on GPU {gpuid}")
         run_bash_command_wrapper(
             f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
             # f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
@@ -217,15 +216,15 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
 
 
 def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
-                     run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=32, gpus=[0], rotating_buffer_size=256,
-                     bias_size=0, icache_flush=False):
+                     run_bench, jobs, iters, skipWarmup, module_name, kernel_name, verbose=0, num_threads=32, gpus=[0],
+                     rotating_buffer_size=256, bias_size=0, icache_flush=False):
 
     # precompile the kernels in parallel
     start_time = datetime.now()
     if not skipWarmup:
         # Generate kernel out of all configs
         fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock,
-                                        init_type, configs, rotating_buffer_size, bias_size)
+                                        init_type, configs, rotating_buffer_size, bias_size, kernel_name)
 
         run_bash_command(f"python {fname} -n {num_threads}", capture=(verbose < 2))
     compile_end = datetime.now()
@@ -235,7 +234,8 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p,
 
     # Generate kernels out of all configs
     generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
-                           jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush)
+                           jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush, module_name,
+                           kernel_name)
 
     # profile generated kernels
     running = [
@@ -377,8 +377,8 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,
     return in_outs
 
 
-def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu,
-           mfmaInstrSize, kpack, use_bias):
+def matmul(kernel_func, a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages,
+           waves_per_eu, mfmaInstrSize, kpack, use_bias):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     #assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -396,7 +396,7 @@ def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m,
     streamk_tiles = m_tiles * n_tiles % num_sms
     # change num_xcds = 1 if using MI250
     num_xcds = 8
-    streamk_gemm[
+    kernel_func[
         grid,
     ](a, b, c, bias, P, locks, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
       stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m,
@@ -405,7 +405,8 @@ def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m,
     return c
 
 
-def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
+def test_correctness(kernel_func, M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector,
+                     verbose):
     block_m, block_n, block_k, group_m, num_sms, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
         config)
     use_bias = bias_vector
@@ -423,8 +424,8 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     locks = torch.zeros((num_sms, ), device="cuda", dtype=torch.int32)
     P = torch.zeros((num_sms, block_m * block_n), device="cuda", dtype=torch.float32)
-    triton_output = matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages,
-                           waves_per_eu, mfmaInstrSize, kpack, use_bias)
+    triton_output = matmul(kernel_func, a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps,
+                           num_stages, waves_per_eu, mfmaInstrSize, kpack, use_bias)
     torch_output = torch.matmul(a_fp16, b_fp16)
     if use_bias:
         torch_output += bias_fp16[:, None]
@@ -435,7 +436,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     size_str = ''
     if verbose:
         size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}'
-    print(f'{size_str} correctness check')
+    print(f'{kernel_func} {size_str} correctness check')
     torch.testing.assert_close(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol)
     print(f'{size_str} Correct✅')
 
@@ -446,6 +447,8 @@ def parse_args():
         allow_abbrev=False,
     )
 
+    parser.add_argument('--kernel', default='streamk_kernel, streamk_gemm',
+                        help='can specify different kernel file name')
     parser.add_argument("-m", type=int, default=0)
     parser.add_argument("-n", type=int, default=0)
     parser.add_argument("-k", type=int, default=0)
@@ -486,11 +489,6 @@ def parse_args():
     parser.add_argument("--hack_triton_compiler", action='store_true', default=False,
                         help="Modify the triton source to avoid backend query")
     args = parser.parse_args()
-    if not args.o:
-        if args.benchmark:
-            args.o = "benchmarking_results.csv"
-        else:
-            args.o = get_default_tuning_result_filename()
 
     return args
 
@@ -542,6 +540,19 @@ def get_rocm_version():
 
 def main():
     args = parse_args()
+    # parse kernel file and kernel function name
+    module_name, kernel_name = args.kernel.split(',')
+    module_name = module_name.strip()
+    kernel_name = kernel_name.strip()
+    module = importlib.import_module(module_name)
+    kernel_func = getattr(module, kernel_name)
+
+    if not args.o:
+        if args.benchmark:
+            args.o = f"benchmarking_results_{kernel_name}.csv"
+        else:
+            args.o = get_default_tuning_result_filename(kernel_name)
+
     matrix_size_file = args.gemm_size_file
     output_file = args.o
     keepTmp = args.keep
@@ -613,7 +624,8 @@ def main():
         for (M, N, K, col_a, col_b, myConfig) in mnks:
             if myConfig is None:
                 raise Exception("kernel config is None, need to provide a tuning config")
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, bias_vector, True)
+            test_correctness(kernel_func, M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig,
+                             bias_vector, True)
         return
 
     configs_full = get_full_tuning_space()
@@ -658,7 +670,7 @@ def main():
         configs += delta_configs
 
         ## Append new configs into the tuning space
-        generate_matmul_kernels(delta_configs)
+        generate_matmul_kernels(delta_configs, module_name, kernel_name)
 
         row_a_str = 'N' if col_a else 'T'
         row_b_str = 'N' if col_b else 'T'
@@ -679,8 +691,9 @@ def main():
         bias_size = M if bias_vector else 0
         minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
             M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs, run_bench,
-            jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level,
-            rotating_buffer_size=rotating_buffer_size, bias_size=bias_size, icache_flush=icache_flush)
+            jobs, iters, skipWarmup, module_name, kernel_name, num_threads=args.num_threads, gpus=gpus,
+            verbose=verbose_level, rotating_buffer_size=rotating_buffer_size, bias_size=bias_size,
+            icache_flush=icache_flush)
 
         # post processing the numbers
         perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -701,9 +714,9 @@ def main():
 
         sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
         sizeDict.update(bestConfig)
+        sizeDict.update({'TFLOPS': formatted_tflops, 'time(us)': minTime})
         if not run_bench:
             f_results.write("- " + str(sizeDict) + " ")
-            f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
 
         # remove generated files if asked to
         if not keepTmp:
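
For reference, a minimal sketch of how the new --kernel option is resolved into a callable, mirroring the importlib logic added to main() above. The helper name resolve_kernel and its standalone form are illustrative only; the tuning script performs these steps inline.

import importlib

def resolve_kernel(spec):
    # spec has the form "<module>, <function>", e.g. the new default
    # 'streamk_kernel, streamk_gemm'; the module must be importable
    # (e.g. streamk_kernel.py next to the tuning script).
    module_name, kernel_name = (part.strip() for part in spec.split(','))
    module = importlib.import_module(module_name)
    kernel_func = getattr(module, kernel_name)
    return module_name, kernel_name, kernel_func

# Hypothetical usage with the default spec; kernel_func is then threaded
# through tune_gemm_config(), matmul() and test_correctness() as in the diff:
# module_name, kernel_name, kernel_func = resolve_kernel('streamk_kernel, streamk_gemm')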