@@ -499,8 +499,9 @@ def generate_mlir_driver_commandline(self, rocmlir_gen_flags, kernel_repeats=MLI
499499 str (self .conv_stride_w ), '--padding_h' ,
500500 str (self .padding_h ), '--padding_w' ,
501501 str (self .padding_w ), '--groupsize' ,
502- str (self .group ), '--kernel-repeats' ,
503- str (kernel_repeats ), f"--perf_config={ self .perfconfig } "
502+ str (self .group ),
503+ * (['--kernel-repeats' , str (kernel_repeats )] if kernel_repeats is not None else []),
504+ f"--perf_config={ self .perfconfig } "
504505 ])
505506 result += ' '
506507 if rocmlir_gen_flags != '' :
@@ -696,7 +697,7 @@ def get_gemm_configurations(filename,
696697
697698 # Skip unsupported datatypes
698699 if datatype == 'f4E2M1FN' :
699- ## TODO: use information from AMDArchDB when it becomes available to determine supported chips
700+ # TODO: use information from AMDArchDB when it becomes available to determine supported chips
700701 supported_chips = {'gfx950' }
701702 if get_chip () not in supported_chips :
702703 continue
@@ -926,32 +927,28 @@ def set_perfconfig(self, perf_config):
926927 self .perfconfig = perf_config
927928
928929 def generate_mlir_driver_commandline (self , rocmlir_gen_flags , kernel_repeats = MLIR_N_REPEATS ):
929- cmd_parts = [
930+ result = ' ' . join ( [
930931 '-operation' , 'gemm' , '-t' , self .datatype , '-out_datatype' , self .out_dtype , '--arch' ,
931932 self .arch , '--num_cu' ,
932933 str (self .num_cu ), '-g' ,
933934 str (self .g ), '-m' ,
934935 str (self .m ), '-k' ,
935936 str (self .k ), '-n' ,
936- str (self .n ), f"-transA={ self .trans_a } " , f"-transB={ self .trans_b } "
937- ]
937+ str (self .n ), f"-transA={ self .trans_a } " , f"-transB={ self .trans_b } " ,
938+ * (['--kernel-repeats' , str (kernel_repeats )] if kernel_repeats is not None else []),
939+ f"--perf_config={ self .perfconfig } "
940+ ])
938941
939942 if self .scaled_gemm :
940- cmd_parts . append ( ' -scaledGemm')
943+ result += ' -scaledGemm'
941944 if self .scale_a_dtype :
942- cmd_parts . extend ([ ' -scale_a_dtype' , self .scale_a_dtype ])
945+ result += f' -scale_a_dtype { self .scale_a_dtype } '
943946 if self .scale_b_dtype :
944- cmd_parts . extend ([ ' -scale_b_dtype' , self .scale_b_dtype ])
947+ result += f' -scale_b_dtype { self .scale_b_dtype } '
945948 if self .trans_scale_a :
946- cmd_parts . append ( f" -transScaleA= { self .trans_scale_a } " )
949+ result += f' -transScaleA { str ( self .trans_scale_a ) } '
947950 if self .trans_scale_b :
948- cmd_parts .append (f"-transScaleB={ self .trans_scale_b } " )
949-
950- cmd_parts .extend (
951- ['--kernel-repeats' ,
952- str (kernel_repeats ), f"--perf_config={ self .perfconfig } " ])
953-
954- result = ' ' .join (cmd_parts )
951+ result += f' -transScaleB { str (self .trans_scale_b )} '
955952
956953 result += ' '
957954 if rocmlir_gen_flags != '' :
@@ -1194,7 +1191,8 @@ def generate_mlir_driver_commandline(self, rocmlir_gen_flags, kernel_repeats=MLI
11941191 f'--dilation_w={ self .dilation_w } ' , f'--conv_stride_h={ self .conv_stride_h } ' ,
11951192 f'--conv_stride_w={ self .conv_stride_w } ' , f'--padding_h={ self .padding_h } ' ,
11961193 f'--padding_w={ self .padding_w } ' , f'--groupsize={ self .group } ' , f'--gemmO={ self .o } ' ,
1197- f'--kernel-repeats={ kernel_repeats } ' , f"--perf_config={ self .perfconfig } "
1194+ * (['--kernel-repeats' , str (kernel_repeats )] if kernel_repeats is not None else []),
1195+ f"--perf_config={ self .perfconfig } "
11981196 ])
11991197 result += ' '
12001198 if rocmlir_gen_flags != '' :
@@ -1363,8 +1361,9 @@ def generate_mlir_driver_commandline(self, rocmlir_gen_flags, kernel_repeats=MLI
13631361 str (self .k ), '-n' ,
13641362 str (self .n ), '-gemmO' ,
13651363 str (self .o ), f"-transA={ self .trans_a } " , f"-transB={ self .trans_b } " ,
1366- f"-transC={ self .trans_c } " , f"-transO={ self .trans_o } " , '--kernel-repeats' ,
1367- str (kernel_repeats ), f"--perf_config={ self .perfconfig } "
1364+ f"-transC={ self .trans_c } " , f"-transO={ self .trans_o } " ,
1365+ * (['--kernel-repeats' , str (kernel_repeats )] if kernel_repeats is not None else []),
1366+ f"--perf_config={ self .perfconfig } "
13681367 ])
13691368 result += ' '
13701369 if rocmlir_gen_flags != '' :
@@ -1691,17 +1690,25 @@ def run_config_with_mlir(config: PerfConfiguration,
16911690 # remove the result file generated by rocprof in previous benchmarking
16921691 if os .path .exists (get_profiler_output_path (arch , BENCHMARKING_STATS_FILE_NAME )):
16931692 os .remove (get_profiler_output_path (arch , BENCHMARKING_STATS_FILE_NAME ))
1694- commandline_options = config .generate_mlir_driver_commandline (rocmlir_gen_flags )
1693+ use_tuning_driver = (not use_rocprof ) and bool (config .perfconfig )
1694+ use_host_harness = not use_tuning_driver
1695+
1696+ rocmlir_gen_flags = rocmlir_gen_flags + ' -ph' if use_host_harness else ''
1697+ # We want to use kernel_repeats only if we are passing ' -ph' to rocmlir-gen, otherwise we use None.
1698+ # This is because the kernel-repeats flag is only supported with host harness or CPU validation.
1699+ kernel_repeats = MLIR_N_REPEATS if use_host_harness else None
1700+
1701+ commandline_options = config .generate_mlir_driver_commandline (rocmlir_gen_flags , kernel_repeats )
1702+ rocmlir_gen_cmd = paths .mlir_paths .rocmlir_gen_path + ' ' + commandline_options
16951703 if debug :
16961704 print ("Running MLIR Benchmark: " , repr (config ))
16971705
16981706 nanoseconds = np .nan
16991707
17001708 # Use HIP timing via tuning-driver if rocprof is disabled and perfconfig is present
1701- if not use_rocprof and config . perfconfig :
1709+ if use_tuning_driver :
17021710 if debug :
17031711 print ("Using HIP timing for benchmarking" )
1704- rocmlir_gen_cmd = paths .mlir_paths .rocmlir_gen_path + ' ' + commandline_options
17051712 tuning_driver_command = [
17061713 paths .mlir_paths .rocmlir_tuning_driver_path , f'--benchmark-config={ config .perfconfig } ' ,
17071714 f'--num-iterations={ MLIR_N_REPEATS } ' , f'--warmup-iterations={ WARMUP_ITERATIONS } ' ,
@@ -1719,7 +1726,6 @@ def run_config_with_mlir(config: PerfConfiguration,
17191726 else :
17201727 if debug :
17211728 print ("Using rocprof for benchmarking" )
1722- rocmlir_gen_cmd = paths .mlir_paths .rocmlir_gen_path + ' -ph ' + commandline_options
17231729 rocmlir_driver_cmd = [paths .mlir_paths .rocmlir_driver_path , '-c' ]
17241730 mlir_cpu_runner_args = [
17251731 f'--shared-libs={ paths .mlir_paths .libmlir_rocm_runtime_path } ,{ paths .mlir_paths .libconv_validation_wrappers_path } ,{ paths .mlir_paths .libmlir_runtime_utils_path } ,{ paths .mlir_paths .libmlir_c_runner_utils_path } ' ,
0 commit comments