@@ -36,6 +36,7 @@ class Options:
3636 flags : list
3737 concurrent_tests : int
3838 numCu : int
39+ logFailures : bool = False
3940
4041class PerfConfig :
4142 class Version (enum .Enum ):
@@ -191,7 +192,7 @@ async def testConfig(config, options: Options, paths: Paths) -> TestResult:
191192 if isinstance (config , MLIROnlyConfig ):
192193 rocmlirGenOpts = config .generateMlirDriverCommandLine (options .flags )
193194 else :
194- rocmlirGenOpts = config .generateMlirDriverCommandLine (' ' .join (options .flags )).split ()
195+ rocmlirGenOpts = config .generateMlirDriverCommandLine (' ' .join (options .flags ), kernel_repeats = None ).split ()
195196 if getattr (config , "currentSeqLen" ) is not None :
196197 rocmlirGenOpts .append (f"--current_seq_len={ ',' .join (map (str , config .currentSeqLen ))} " )
197198 rocmlirGenOpts .append ('-pv' )
@@ -264,13 +265,17 @@ async def testConfig(config, options: Options, paths: Paths) -> TestResult:
264265Return code = { runner .returncode } """ , file = sys .stderr )
265266 return TestResult .FAIL
266267
267- if not CORRECT_RESULT_RE .search (runnerOut ):
268+ output_lines = [line .strip () for line in runnerOut .splitlines () if len (line .strip ()) > 0 ]
269+ expected_output = "[1 1 1]"
270+ all_correct = all (line == expected_output for line in output_lines )
271+ if not all_correct :
268272 print (f"""Config returned incorrect result
269273Output = { runnerOut }
270274Errors = { runnerErrs .decode ('utf-8' )} """ , file = sys .stderr )
271275 return TestResult .FAIL
272276 return TestResult .PASS
273277
278+
274279IterType = TypeVar ('IterType' )
275280def grouper (iterable : Iterable [IterType ], n : int ):
276281 it = iter (iterable )
@@ -288,8 +293,16 @@ async def dropGoodConfig(config, options: Options, paths: Paths):
288293 if isinstance (config , MLIROnlyConfig ):
289294 print (f"{ result .name } : { config !r} " )
290295 else :
296+ print ("-" * 100 )
291297 print (f"{ result .name } : { multilineRepr (config )} " )
292298 if result == TestResult .FAIL :
299+ if options .logFailures :
300+ if isinstance (config , perfRunner .AttentionConfiguration ):
301+ with open ("failing_attn_configs.txt" , "a" ) as f :
302+ f .write (multilineRepr (config ) + "\n " )
303+ else :
304+ with open ("failing_conv_configs.txt" , "a" ) as f :
305+ f .write (multilineRepr (config ) + "\n " )
293306 return config
294307 return result
295308
@@ -459,7 +472,7 @@ async def runConfig(paramIter: Iterable[IterType],
459472 if len (failures ) != 0 :
460473 print ("*** Summary of failures ***" )
461474 for c in failures :
462- print (' ' .join (c .generateMlirDriverCommandLine (options .flags )))
475+ print (' ' .join (c .generateMlirDriverCommandLine (options .flags , kernel_repeats = None )))
463476 print (f"Passed: { n_passes } , Invalid: { n_invalids } , Failed: { len (failures )} " )
464477 return len (failures ) == 0
465478
@@ -482,6 +495,8 @@ def main() -> bool:
482495 help = 'Use xdlops when generating kernels (default off)' )
483496 parser .add_argument ('--no-xdlops' , '-X' , dest = 'xdlops' , action = 'store_false' ,
484497 help = 'Explicitly disable xdlops usage' )
498+ parser .add_argument ('--log-failures' , '-L' , action = 'store_true' , default = False ,
499+ help = 'Save failures to file' )
485500 parser .add_argument (
486501 '--codepath' ,
487502 type = str ,
@@ -527,7 +542,7 @@ def main() -> bool:
527542 # unknow arch info
528543 print (f"""Unknown arch { arch } """ , file = sys .stderr )
529544
530- options = Options (debug = args .debug , quiet = args .quiet ,
545+ options = Options (debug = args .debug , quiet = args .quiet , logFailures = args . log_failures ,
531546 arch = arch , flags = rocmlir_gen_flags , concurrent_tests = args .jobs , numCu = getNumCU (perfRunner .getChip ()))
532547
533548 paths = perfRunner .create_paths (None , args .mlir_build_dir )
0 commit comments