Skip to content

Commit 720af5a

Browse files
authored
Parameter Sweeps for Attention: Check all outputs, log failures, avoid kernel repeats (#1914)
Check that all outputs are [1 1 1]; write each failing config to a file as soon as it is detected, for both attention (attn) and convolution (conv) sweeps; change the usage of kernel_repeats so that it is 1 when used in parameter sweeps; and include currentSeqLen when printing the attention configuration. --------- Signed-off-by: Djordje Antic <[email protected]>
1 parent 4423eaf commit 720af5a

File tree

4 files changed

+25 -12 lines changed

mlir/utils/jenkins/Jenkinsfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ void check_randomE2ETests(String codepath) {
393393
void parameterSweep(String CONFIG, String codepath) {
394394
timeout(time: 300, activity: true, unit: 'MINUTES') {
395395
dir('build') {
396-
sh """python3 ./bin/parameterSweeps.py -j 5 ${CONFIG}"""
396+
sh """python3 ./bin/parameterSweeps.py -j 5 ${CONFIG} --log-failures"""
397397
}
398398
}
399399
}

mlir/utils/performance/attentionSweeps.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
# GLOBAL VARIABLES
3333
DATA_TYPES_ATTENTION = initializeDataTypesAttention()
3434
BOOLS = [True, False]
35-
LOGFILE = 'failing_configs.csv'
3635

3736
# Week number is used as seed to make sure weekly CI is reproducible
3837
seed = datetime.utcnow().isocalendar()[1]
@@ -156,7 +155,7 @@ def logFailingConfigs(configs: List[AttentionConfiguration], filename: str):
156155
writer = csv.writer(csvfile)
157156
writer.writerow(['CommandLine'])
158157
for config in configs:
159-
writer.writerow([' '.join(config.generateMlirDriverCommandLine(''))])
158+
writer.writerow([config.generateMlirDriverCommandLine('', kernel_repeats=None)])
160159

161160
def main():
162161
parser = argparse.ArgumentParser(
@@ -181,7 +180,8 @@ def main():
181180
arch=arch,
182181
flags=[],
183182
concurrent_tests=args.jobs,
184-
numCu=getNumCU(chip)
183+
numCu=getNumCU(chip),
184+
logFailures=args.log_failures
185185
)
186186

187187

@@ -206,8 +206,6 @@ def main():
206206
print(f"{'Failing Configurations':^80}\n")
207207
for fail in failing:
208208
print(multilineRepr(fail))
209-
if args.log_failures:
210-
logFailingConfigs(failing, LOGFILE)
211209

212210
print(f"\nPassed: {passed}, Invalid: {invalid}, Failed: {len(failing)}")
213211

mlir/utils/performance/parameterSweeps.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Options:
3636
flags: list
3737
concurrent_tests: int
3838
numCu: int
39+
logFailures: bool = False
3940

4041
class PerfConfig:
4142
class Version(enum.Enum):
@@ -191,7 +192,7 @@ async def testConfig(config, options: Options, paths: Paths) -> TestResult:
191192
if isinstance(config, MLIROnlyConfig):
192193
rocmlirGenOpts = config.generateMlirDriverCommandLine(options.flags)
193194
else:
194-
rocmlirGenOpts = config.generateMlirDriverCommandLine(' '.join(options.flags)).split()
195+
rocmlirGenOpts = config.generateMlirDriverCommandLine(' '.join(options.flags), kernel_repeats=None).split()
195196
if getattr(config, "currentSeqLen") is not None:
196197
rocmlirGenOpts.append(f"--current_seq_len={','.join(map(str, config.currentSeqLen))}")
197198
rocmlirGenOpts.append('-pv')
@@ -264,13 +265,17 @@ async def testConfig(config, options: Options, paths: Paths) -> TestResult:
264265
Return code = {runner.returncode}""", file=sys.stderr)
265266
return TestResult.FAIL
266267

267-
if not CORRECT_RESULT_RE.search(runnerOut):
268+
output_lines = [line.strip() for line in runnerOut.splitlines() if len(line.strip()) > 0]
269+
expected_output = "[1 1 1]"
270+
all_correct = all(line == expected_output for line in output_lines)
271+
if not all_correct:
268272
print(f"""Config returned incorrect result
269273
Output = {runnerOut}
270274
Errors = {runnerErrs.decode('utf-8')}""", file=sys.stderr)
271275
return TestResult.FAIL
272276
return TestResult.PASS
273277

278+
274279
IterType = TypeVar('IterType')
275280
def grouper(iterable: Iterable[IterType], n: int):
276281
it = iter(iterable)
@@ -288,8 +293,16 @@ async def dropGoodConfig(config, options: Options, paths: Paths):
288293
if isinstance(config, MLIROnlyConfig):
289294
print(f"{result.name}: {config!r}")
290295
else:
296+
print("-" * 100)
291297
print(f"{result.name}: {multilineRepr(config)}")
292298
if result == TestResult.FAIL:
299+
if options.logFailures:
300+
if isinstance(config, perfRunner.AttentionConfiguration):
301+
with open("failing_attn_configs.txt", "a") as f:
302+
f.write(multilineRepr(config) + "\n")
303+
else:
304+
with open("failing_conv_configs.txt", "a") as f:
305+
f.write(multilineRepr(config) + "\n")
293306
return config
294307
return result
295308

@@ -459,7 +472,7 @@ async def runConfig(paramIter: Iterable[IterType],
459472
if len(failures) != 0:
460473
print("*** Summary of failures ***")
461474
for c in failures:
462-
print(' '.join(c.generateMlirDriverCommandLine(options.flags)))
475+
print(' '.join(c.generateMlirDriverCommandLine(options.flags, kernel_repeats=None)))
463476
print(f"Passed: {n_passes}, Invalid: {n_invalids}, Failed: {len(failures)}")
464477
return len(failures) == 0
465478

@@ -482,6 +495,8 @@ def main() -> bool:
482495
help='Use xdlops when generating kernels (default off)')
483496
parser.add_argument('--no-xdlops', '-X', dest='xdlops', action='store_false',
484497
help='Explicitly disable xdlops usage')
498+
parser.add_argument('--log-failures', '-L', action='store_true', default=False,
499+
help='Save failures to file')
485500
parser.add_argument(
486501
'--codepath',
487502
type=str,
@@ -527,7 +542,7 @@ def main() -> bool:
527542
# unknown arch info
528543
print(f"""Unknown arch {arch}""", file=sys.stderr)
529544

530-
options = Options(debug=args.debug, quiet=args.quiet,
545+
options = Options(debug=args.debug, quiet=args.quiet, logFailures=args.log_failures,
531546
arch=arch, flags=rocmlir_gen_flags, concurrent_tests=args.jobs, numCu=getNumCU(perfRunner.getChip()))
532547

533548
paths = perfRunner.create_paths(None, args.mlir_build_dir)

mlir/utils/performance/perfRunner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,7 @@ def tableEntry(self, nanoSeconds):
13371337
def setPerfConfig(self, perf_config):
13381338
self.perfConfig = perf_config
13391339

1340-
def generateMlirDriverCommandLine(self, rocmlir_gen_flags):
1340+
def generateMlirDriverCommandLine(self, rocmlir_gen_flags, kernel_repeats=MLIR_N_REPEATS):
13411341
result = ' '.join(['-operation', 'attention',
13421342
'-t', self.dataType,
13431343
'--arch', self.arch,
@@ -1357,7 +1357,7 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags):
13571357
f"-transO={self.transO}",
13581358
f"-causal={self.causal}",
13591359
f"-return_lse={self.return_lse}",
1360-
'--kernel-repeats', str(MLIR_N_REPEATS),
1360+
*(['--kernel-repeats', str(kernel_repeats)] if kernel_repeats is not None else []),
13611361
f"--perf_config={self.perfConfig}"])
13621362
result += ' '
13631363
if rocmlir_gen_flags != '':

0 commit comments

Comments
 (0)