#!/usr/bin/env python3
"""Sweeps the parameters of the rocmlir driver for bugs in attention-based kernel configurations.

Usage:
    python3 attentionSweeps.py --mlir-build-dir <path-to-mlir-build-dir> [options]

Options:
    --mlir-build-dir  Path to the MLIR build directory (default: auto-detected)
    --samples         Number of random configuration samples to test (default: 1000)
    --jobs            Number of concurrent tests to run in parallel (default: os.cpu_count())
    --debug           Enable debug output
    --quiet           Disable per-test result output
    --log-failures    Save failing configurations to a CSV file
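
Example (illustrative flag values):
    python3 attentionSweeps.py --samples 200 --jobs 8 --log-failures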
| 14 | +""" |
import argparse
import asyncio
import csv
import itertools
import os
import random
import sys
from datetime import datetime, timezone
from typing import Iterable, List, TypeVar

from perfRunner import AttentionConfiguration
from perfRunner import getArch, getNumCU, initializeDataTypesAttention
from perfRunner import create_paths as createPaths
from perfRunner import find_mlir_build_dir as findMlirBuildDir
from perfRunner import GFX_CHIP_RE
from parameterSweeps import Options, sweepParameters, multilineRepr

# GLOBAL VARIABLES
DATA_TYPES_ATTENTION = initializeDataTypesAttention()
BOOLS = [True, False]
LOGFILE = 'failing_configs.csv'

# The ISO week number is used as the seed so that each weekly CI run is reproducible.
seed = datetime.now(timezone.utc).isocalendar()[1]
random.seed(seed)

def toAttentionConfig(params, options: Options) -> AttentionConfiguration:
    """Converts a sampled parameter tuple into an AttentionConfiguration instance."""
    shape, perf = params
    *shapeParams, currentSeqLen = shape
    dtype, g, slq, slk, nhq, nhkv, hdqk, hdv, scale, bias, tq, tk, tv, to, causal, rlse = shapeParams
    perfString = f"attn:v1:{','.join(str(x) for x in perf)}"
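    # e.g. the MFMA tuple (64, 64, 128, 16, 64, 16, 4, 1) yields "attn:v1:64,64,128,16,64,16,4,1"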
    attnConfig = AttentionConfiguration(
        dtype=dtype,
        g=g,
        seq_len_q=slq,
        seq_len_k=slk,
        num_heads_q=nhq,
        num_heads_kv=nhkv,
        head_dim_qk=hdqk,
        head_dim_v=hdv,
        with_attn_scale=scale,
        with_attn_bias=bias,
        transQ=tq,
        transK=tk,
        transV=tv,
        transO=to,
        causal=causal,
        return_lse=rlse,
        arch=options.arch,
        numCU=options.numCu,
        perf_config=perfString
    )
    attnConfig.currentSeqLen = currentSeqLen
    return attnConfig

IterType = TypeVar('IterType')
def grouper(iterable: Iterable[IterType], n: int):
    """Yields successive chunks of up to n items from iterable.

    e.g. list(grouper('ABCDE', 2)) == [('A', 'B'), ('C', 'D'), ('E',)]
    """
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

def genCurrentSeqLens(g: int, maxSeqLen: int) -> list[int]:
    """Generates a random current sequence length in [0, maxSeqLen) for each of the g groups."""
    return [random.randint(0, maxSeqLen - 1) for _ in range(g)]

def sampleAttentionShape():
    g = random.randint(1, 256)          # GROUPS
    seqLenK = random.randint(1, 16384)  # SEQ_LEN_K

    useKVCache = random.choice(BOOLS)
    currentSeqLen = genCurrentSeqLens(g, seqLenK) if useKVCache else None
    seqLenQ = 1 if useKVCache else random.randint(1, 16384)  # SEQ_LEN_Q

    # By default numHeadsQ and numHeadsKV are both 1; when they are equal,
    # GQA is disabled. Both values are typically powers of 2, and numHeadsQ
    # must be divisible by numHeadsKV. Here we decide randomly whether to
    # use values different from the defaults.
    #
    # Requirements:
    #   - numHeadsQ > numHeadsKV (equal values leave GQA disabled)
    #   - numHeadsQ % numHeadsKV == 0
    numHeadsQ = 1
    numHeadsKV = 1
    genNumHeads = random.choice(BOOLS)
    if genNumHeads:
        # Rejection-sample powers of two until the GQA constraints hold.
        while True:
            numHeadsQ = 2**random.randint(1, 6)
            numHeadsKV = 2**random.randint(1, 6)

            if numHeadsQ > numHeadsKV and numHeadsQ % numHeadsKV == 0:  # found a valid pair
                break

    return (
        random.choice(DATA_TYPES_ATTENTION),
        g,                        # GROUPS
        seqLenQ,                  # SEQ_LEN_Q
        seqLenK,                  # SEQ_LEN_K
        numHeadsQ,                # NUM_HEADS_Q
        numHeadsKV,               # NUM_HEADS_KV
        random.randint(1, 1024),  # HEAD_DIM_QK
        random.randint(1, 1024),  # HEAD_DIM_V
        random.choice(BOOLS),     # with_attn_scale
        random.choice(BOOLS),     # with_attn_bias
        random.choice(BOOLS),     # transQ
        random.choice(BOOLS),     # transK
        random.choice(BOOLS),     # transV
        random.choice(BOOLS),     # transO
        random.choice(BOOLS),     # causal
        random.choice(BOOLS),     # return_lse
        currentSeqLen
    )

# Keep in sync with RockTuningImpl.cpp
perfConfigSpaceMFMA = list(itertools.product(  # MFMA perfConfig space
    [32, 64, 128, 256],  # M/block G0
    [32, 64, 128, 256],  # M/block G1
    [32, 64, 128, 256],  # N/block G0
    [8, 16, 32, 64],     # Kpack/Block
    [32, 64, 128, 256],  # M/Wave
    [4, 16, 32],         # MN/Xdl
    [4, 8, 16],          # kPack
    [0, 1]               # forceUnroll
))
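# Assuming independent choices, the MFMA space has 4^5 * 3 * 3 * 2 = 18432 candidate configs.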

perfConfigSpaceWMMA = list(itertools.product(  # WMMA perfConfig space
    [32, 64, 128],       # M/block G0
    [32, 64, 128],       # M/block G1
    [32, 64, 128, 256],  # N/block G0
    [8, 16, 32, 64],     # Kpack/Block
    [32, 64],            # M/Wave
    [32, 64],            # N/Wave
    [4, 8, 16],          # kPack
    [0, 1]               # forceUnroll
))
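# Similarly, the WMMA space has 3 * 3 * 4 * 4 * 2 * 2 * 3 * 2 = 3456 candidate configs.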

def logFailingConfigs(configs: List[AttentionConfiguration], filename: str):
    """Writes the driver command line for each failing configuration to a CSV file."""
    with open(filename, mode='w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['CommandLine'])
        for config in configs:
            writer.writerow([' '.join(config.generateMlirDriverCommandLine(''))])

def main():
    parser = argparse.ArgumentParser(
        description='Sweep parameter values for attention to detect bugs')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--quiet', action='store_true')
    parser.add_argument('--jobs', type=int, default=os.cpu_count())
    parser.add_argument('--mlir-build-dir', type=str, default=findMlirBuildDir())
    parser.add_argument('--samples', type=int, default=1000)
    parser.add_argument('--log-failures', action='store_true')

    args = parser.parse_args()
    arch = getArch()
    chip_match = GFX_CHIP_RE.search(arch)
    if chip_match is None:
        raise RuntimeError(f"Could not find GFX chip in arch string: {arch}")
    chip = chip_match.group(0)
    paths = createPaths(None, args.mlir_build_dir)
    options = Options(
        debug=args.debug,
        quiet=args.quiet,
        arch=arch,
        flags=[],
        concurrent_tests=args.jobs,
        numCu=getNumCU(chip)
    )

    if not args.quiet:
        print(f"Sampling {args.samples} configurations from attention space...")

    # TODO: use AmdArchDb python version when available

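    # Assumption: gfx9* (CDNA-class) chips use the MFMA space; other chips fall back to WMMA.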
    if chip.startswith('gfx9'):
        perfConfigSpace = perfConfigSpaceMFMA
    else:
        perfConfigSpace = perfConfigSpaceWMMA

    samples = [
        (sampleAttentionShape(), random.choice(perfConfigSpace))
        for _ in range(args.samples)
    ]

    passed, invalid, failing = asyncio.run(sweepParameters(samples, toAttentionConfig, options, paths))
    if failing:
        print("\n" + "-" * 80)
        print(f"{'Failing Configurations':^80}\n")
        for fail in failing:
            print(multilineRepr(fail))
        if args.log_failures:
            logFailingConfigs(failing, LOGFILE)

    print(f"\nPassed: {passed}, Invalid: {invalid}, Failed: {len(failing)}")

    return 0

if __name__ == '__main__':
    ret = main()
    sys.exit(ret)