Skip to content

Commit 2c251fd

Browse files
authored
Parameter Sweep for Attention (#1830)
Implement parameterSweep for attention script which tests combinations of input shapes and perfConfigs for attention. --------- Signed-off-by: Djordje Antic <[email protected]>
1 parent dc212ab commit 2c251fd

File tree

4 files changed

+277
-8
lines changed

4 files changed

+277
-8
lines changed

mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
namespace mlir {
3232
namespace rock {
3333

34+
// Keep in sync with attentionSweeps.py
3435
// The full space is a brute-force search for attention kernels
3536
template <typename Op>
3637
static void createAttnTuningRangeBF(TuningParamSet *newSpace, Op attnOp,
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
#!/usr/bin/env python3
2+
"""Sweeps the parameters of the rocmlir driver for bugs for attention-based kernel configurations.
3+
4+
Usage:
5+
python3 attentionSweeps.py --mlir-build-dir <path-to-mlir-build-dir> [options]
6+
7+
Options:
8+
--mlir-build-dir Path to the MLIR build directory (default: auto-detected)
9+
--samples Number of random configuration samples to test (default: 1000)
10+
--jobs Number of concurrent tests to run in parallel (default: os.cpu_count())
11+
--debug Enable debug output
12+
--quiet Disable per-test result output
13+
--log-failures Save failing configurations to csv file
14+
"""
15+
import argparse
import asyncio
import csv
import itertools
import os
import random
import sys
from datetime import datetime, timezone
from typing import Iterable, List, TypeVar

from perfRunner import AttentionConfiguration
from perfRunner import getArch, getNumCU, initializeDataTypesAttention
from perfRunner import create_paths as createPaths
from perfRunner import find_mlir_build_dir as findMlirBuildDir
from perfRunner import DATA_TYPES_ATTENTION, GFX_CHIP_RE
from parameterSweeps import Options, sweepParameters, multilineRepr
31+
32+
# GLOBAL VARIABLES
# Rebinds the imported name to the list returned by perfRunner's initializer.
DATA_TYPES_ATTENTION = initializeDataTypesAttention()
BOOLS = [True, False]
LOGFILE = 'failing_configs.csv'

# Week number is used as seed to make sure weekly CI is reproducible.
# datetime.utcnow() is deprecated since Python 3.12, so an aware UTC timestamp
# is used instead; the ISO week number it yields is the same.
seed = datetime.now(timezone.utc).isocalendar()[1]
random.seed(seed)
40+
41+
def toAttentionConfig(params, options: Options) -> AttentionConfiguration:
    """Converts a sampled parameter tuple into an AttentionConfiguration instance.

    `params` is a (shape, perf) pair. `shape` holds sixteen shape fields
    followed by the current sequence lengths (a list, or None), and `perf` is
    one tuple of perfConfig values. The perfConfig values are serialized as
    "attn:v1:<comma-separated values>"; arch and CU count are read from
    `options`. The current sequence lengths are attached to the returned
    object as its `currentSeqLen` attribute.
    """
    shapeFields, perfValues = params
    (dataType, groups, seqLenQ, seqLenK, numHeadsQ, numHeadsKV,
     headDimQK, headDimV, withScale, withBias,
     transQ, transK, transV, transO,
     isCausal, returnLse, currentSeqLen) = shapeFields

    perfString = "attn:v1:" + ",".join(map(str, perfValues))

    attnConfig = AttentionConfiguration(
        dtype=dataType,
        g=groups,
        seq_len_q=seqLenQ,
        seq_len_k=seqLenK,
        num_heads_q=numHeadsQ,
        num_heads_kv=numHeadsKV,
        head_dim_qk=headDimQK,
        head_dim_v=headDimV,
        with_attn_scale=withScale,
        with_attn_bias=withBias,
        transQ=transQ,
        transK=transK,
        transV=transV,
        transO=transO,
        causal=isCausal,
        return_lse=returnLse,
        arch=options.arch,
        numCU=options.numCu,
        perf_config=perfString,
    )
    # Not a constructor argument: stored on the instance for later use.
    attnConfig.currentSeqLen = currentSeqLen
    return attnConfig
70+
71+
IterType = TypeVar('IterType')
def grouper(iterable: Iterable[IterType], n: int):
    """Yields consecutive tuples of up to `n` items taken from `iterable`.

    The final tuple may be shorter than `n`; nothing is yielded for an empty
    input.
    """
    source = iter(iterable)
    # Keep pulling n-sized slices until a slice comes back empty.
    for batch in iter(lambda: tuple(itertools.islice(source, n)), ()):
        yield batch
79+
80+
def genCurrentSeqLens(g: int, maxSeqLen: int) -> list[int]:
    """Returns `g` random integers, each drawn from [0, maxSeqLen - 1]."""
    upperBound = maxSeqLen - 1
    lengths = []
    for _ in range(g):
        lengths.append(random.randint(0, upperBound))
    return lengths
82+
83+
def sampleAttentionShape():
    """Draws one random attention problem shape.

    Returns a 17-element tuple: (dtype, groups, seqLenQ, seqLenK, numHeadsQ,
    numHeadsKV, headDimQK, headDimV, with_attn_scale, with_attn_bias, transQ,
    transK, transV, transO, causal, return_lse, currentSeqLen).

    The order of the random draws below is significant: the module seeds the
    global generator, so reordering them would change which shapes a given
    seed produces.
    """
    groups = random.randint(1, 256)  # GROUPS
    seqLenK = random.randint(1, 16384)  # SEQ_LEN_K

    # With a KV cache the query length is fixed to 1 and every group gets its
    # own current sequence length; otherwise the query length is sampled.
    if random.choice(BOOLS):
        currentSeqLen = genCurrentSeqLens(groups, seqLenK)
        seqLenQ = 1
    else:
        currentSeqLen = None
        seqLenQ = random.randint(1, 16384)  # SEQ_LEN_Q

    # By default numHeadsQ and numHeadsKV are both 1. If numHeadsQ and
    # numHeadsKV are equal, GQA is disabled. Both values are typically powers
    # of 2, and numHeadsQ is divisible by numHeadsKV. Here we decide randomly
    # whether to use head counts different from the defaults.
    #
    # Stated requirements:
    #   - numHeadsQ >= numHeadsKV
    #   - numHeadsQ % numHeadsKV == 0
    # NOTE(review): the loop accepts only numHeadsQ > numHeadsKV, so equal
    # head counts above 1 are never sampled - confirm this is intended.
    numHeadsQ = numHeadsKV = 1
    if random.choice(BOOLS):
        foundValidPair = False
        while not foundValidPair:
            numHeadsQ = 2**random.randint(1, 6)
            numHeadsKV = 2**random.randint(1, 6)
            foundValidPair = (numHeadsQ > numHeadsKV
                              and numHeadsQ % numHeadsKV == 0)

    dataType = random.choice(DATA_TYPES_ATTENTION)
    headDimQK = random.randint(1, 1024)  # HEAD_DIM_QK
    headDimV = random.randint(1, 1024)  # HEAD_DIM_V
    # with_attn_scale, with_attn_bias, transQ, transK, transV, transO,
    # causal, return_lse - drawn in this order.
    boolFlags = tuple(random.choice(BOOLS) for _ in range(8))

    return (
        dataType,
        groups,
        seqLenQ,
        seqLenK,
        numHeadsQ,
        numHeadsKV,
        headDimQK,
        headDimV,
        *boolFlags,
        currentSeqLen,
    )
130+
131+
# Keep in sync with RockTuningImpl.cpp
# Each space is the Cartesian product of the candidate values of every
# perfConfig field, in the order the fields are serialized.
perfConfigSpaceMFMA = list(itertools.product(*(  # MFMA perfConfig space
    (32, 64, 128, 256),  # M/block G0
    (32, 64, 128, 256),  # M/block G1
    (32, 64, 128, 256),  # N/block G0
    (8, 16, 32, 64),  # Kpack/Block
    (32, 64, 128, 256),  # M/Wave
    (4, 16, 32),  # MN/Xdl
    (4, 8, 16),  # kPack
    (0, 1),  # forceUnroll
)))

perfConfigSpaceWMMA = list(itertools.product(*(  # WMMA perfConfig space
    (32, 64, 128),  # M/block G0
    (32, 64, 128),  # M/block G1
    (32, 64, 128, 256),  # N/block G0
    (8, 16, 32, 64),  # Kpack/Block
    (32, 64),  # M/Wave
    (32, 64),  # N/Wave
    (4, 8, 16),  # kPack
    (0, 1),  # forceUnroll
)))
153+
154+
def logFailingConfigs(configs: List[AttentionConfiguration], filename: str):
    """Writes the driver command line of every failing config to a CSV file.

    The file is overwritten and gets a single 'CommandLine' header column
    followed by one row per configuration.

    Args:
        configs: failing configurations to record.
        filename: path of the CSV file to create.
    """
    with open(filename, mode='w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['CommandLine'])
        for config in configs:
            commandLine = config.generateMlirDriverCommandLine('')
            # The sweep runner calls .split() on this value for attention
            # configs, so it is a string there; joining a string would insert
            # a space between every character. Only join real sequences.
            if not isinstance(commandLine, str):
                commandLine = ' '.join(commandLine)
            # The sweep runner passes the current sequence lengths as a
            # separate option; record it too so the failure can be reproduced.
            currentSeqLen = getattr(config, 'currentSeqLen', None)
            if currentSeqLen is not None:
                seqLens = ','.join(map(str, currentSeqLen))
                commandLine = f"{commandLine} --current_seq_len={seqLens}"
            writer.writerow([commandLine])
160+
161+
def main():
    """Samples attention configurations, runs the sweep, and reports results.

    Parses the command line, queries the architecture, picks the MFMA
    perfConfig space for chips whose name starts with 'gfx9' and the WMMA
    space otherwise, pairs each randomly sampled shape with a random
    perfConfig, and hands the samples to sweepParameters. Failing
    configurations are printed and, with --log-failures, saved to LOGFILE.

    Returns:
        0 in all cases.

    Raises:
        RuntimeError: if no GFX chip name is found in the arch string.
    """
    parser = argparse.ArgumentParser(
        description='Sweep parameter values for attention to detect bugs')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug output')
    parser.add_argument('--quiet', action='store_true',
                        help='Disable per-test result output')
    parser.add_argument('--jobs', type=int, default=os.cpu_count(),
                        help='Number of concurrent tests to run in parallel '
                             '(default: os.cpu_count())')
    parser.add_argument('--mlir-build-dir', type=str, default=findMlirBuildDir(),
                        help='Path to the MLIR build directory '
                             '(default: auto-detected)')
    parser.add_argument('--samples', type=int, default=1000,
                        help='Number of random configuration samples to test '
                             '(default: 1000)')
    parser.add_argument('--log-failures', action='store_true',
                        help='Save failing configurations to csv file')

    args = parser.parse_args()
    arch = getArch()
    chip_match = GFX_CHIP_RE.search(arch)
    if chip_match is None:
        raise RuntimeError(f"Could not find GFX chip in arch string: {arch}")
    chip = chip_match.group(0)
    paths = createPaths(None, args.mlir_build_dir)
    options = Options(
        debug=args.debug,
        quiet=args.quiet,
        arch=arch,
        flags=[],
        concurrent_tests=args.jobs,
        numCu=getNumCU(chip)
    )

    if not args.quiet:
        print(f"Sampling {args.samples} configurations from attention space...")

    # TODO: use AmdArchDb python version when available
    if chip.startswith('gfx9'):
        perfConfigSpace = perfConfigSpaceMFMA
    else:
        perfConfigSpace = perfConfigSpaceWMMA

    # Shape first, perfConfig second: the draw order fixes what a seed yields.
    samples = [
        (sampleAttentionShape(), random.choice(perfConfigSpace))
        for _ in range(args.samples)
    ]

    passed, invalid, failing = asyncio.run(sweepParameters(samples, toAttentionConfig, options, paths))
    if failing:
        print("\n" + "-" * 80)
        print(f"{'Failing Configurations':^80}\n")
        for fail in failing:
            print(multilineRepr(fail))
        if args.log_failures:
            logFailingConfigs(failing, LOGFILE)

    print(f"\nPassed: {passed}, Invalid: {invalid}, Failed: {len(failing)}")

    # NOTE(review): the exit status is 0 even when configurations failed -
    # confirm CI detects failures from the printed summary, not the status.
    return 0
215+
216+
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())

mlir/utils/performance/parameterSweeps.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Options:
3434
arch: str
3535
flags: list
3636
concurrent_tests: int
37+
numCu: int
3738

3839
class PerfConfig:
3940
class Version(enum.Enum):
@@ -137,16 +138,61 @@ def __init__(self, dtype: str, direction: str, layout: str,
137138
self.ho = math.floor((self.hi + self.paddingHL + self.paddingHR - (self.y - 1) * self.dilationH - 1 ) / self.convStrideH) + 1
138139
self.wo = math.floor((self.wi + self.paddingWL + self.paddingWR * 2 - (self.x - 1) * self.dilationW - 1 ) / self.convStrideW) + 1
139140

141+
def multilineRepr(obj, num_fields=4):
    """Returns a multi-line string representation of the given object.

    repr(obj) is flattened to one line and split at every comma that is not
    inside single quotes. The resulting fields are emitted `num_fields` per
    line, with every continuation line prefixed by a tab. Everything from the
    first occurrence of 'perf_config=' onwards is kept intact on a final line
    of its own, so the commas inside the perfConfig string never split it.

    A trailing comma is dropped only when it would end the whole output;
    the comma before a following perf_config line is always kept.

    Args:
        obj: any object; only its repr() is used.
        num_fields: number of comma-separated fields per output line.

    Returns:
        The formatted representation as a single string with embedded
        newlines.
    """
    flat = repr(obj).replace('\n', ' ')  # Flatten to one line
    fields = []
    perf_config_str = None
    in_quotes = False
    start = 0  # Index where the field currently being scanned begins
    end = len(flat)  # Index where field scanning stopped
    for i, c in enumerate(flat):
        # Detect start of perf_config to prevent it from being split
        if flat.startswith('perf_config=', i):
            perf_config_str = flat[i:].strip()
            end = i
            break
        if c == "'":
            in_quotes = not in_quotes
        elif c == ',' and not in_quotes:
            fields.append(flat[start:i].strip() + ',')
            start = i + 1

    # Whatever follows the last comma is a field only if it has real content;
    # the blank between ", " and "perf_config=" must not become an empty
    # field (it used to yield a trailing space or a tab-only line).
    tail = flat[start:end].strip()
    if tail:
        fields.append(tail)

    # A dangling comma is only wrong when nothing follows it in the output.
    if fields and not perf_config_str and fields[-1].endswith(','):
        fields[-1] = fields[-1][:-1]

    lines = []
    for j in range(0, len(fields), num_fields):
        prefix = '\t' if j > 0 else ''
        lines.append(prefix + ' '.join(fields[j:j + num_fields]))
    if perf_config_str:
        lines.append('\t' + perf_config_str)

    return '\n'.join(lines)
140181

141182
class TestResult(enum.Enum):
142183
PASS = 1
143184
INVALID = 2
144185
FAIL = 3
145186

146-
async def testConfig(config: MLIROnlyConfig, options: Options, paths: Paths) -> TestResult:
187+
async def testConfig(config, options: Options, paths: Paths) -> TestResult:
147188
"""Runs the given configuration and returns whether it successfully concluded,
148189
failed validation, or was inapplicable."""
149-
rocmlirGenOpts = config.generateMlirDriverCommandLine(options.flags)
190+
if isinstance(config, MLIROnlyConfig):
191+
rocmlirGenOpts = config.generateMlirDriverCommandLine(options.flags)
192+
else:
193+
rocmlirGenOpts = config.generateMlirDriverCommandLine(' '.join(options.flags)).split()
194+
if getattr(config, "currentSeqLen") is not None:
195+
rocmlirGenOpts.append(f"--current_seq_len={','.join(map(str, config.currentSeqLen))}")
150196
rocmlirGenOpts.append('-pv')
151197

152198
applicableFromGen, genToApplicable = os.pipe()
@@ -218,7 +264,7 @@ async def testConfig(config: MLIROnlyConfig, options: Options, paths: Paths) ->
218264
return TestResult.FAIL
219265

220266
if not CORRECT_RESULT_RE.search(runnerOut):
221-
print(f"""Convolution returned intorrect result
267+
print(f"""Config returned incorrect result
222268
Output = {runnerOut}
223269
Errors = {runnerErrs.decode('utf-8')}""", file=sys.stderr)
224270
return TestResult.FAIL
@@ -233,20 +279,22 @@ def grouper(iterable: Iterable[IterType], n: int):
233279
return
234280
yield chunk
235281

236-
async def dropGoodConfig(config: ConvConfiguration,
237-
options: Options, paths: Paths) -> Union[TestResult, ConvConfiguration]:
282+
async def dropGoodConfig(config, options: Options, paths: Paths):
    """Runs `config` through testConfig and reports the outcome.

    Unless options.quiet is set, prints one "<RESULT>: <config>" line, using
    the plain repr for MLIROnlyConfig instances and multilineRepr for every
    other configuration type.

    Returns the config itself when the result is TestResult.FAIL; otherwise
    returns the TestResult value.
    """
    outcome = await testConfig(config, options, paths)
    if not options.quiet:
        if isinstance(config, MLIROnlyConfig):
            shown = repr(config)
        else:
            shown = multilineRepr(config)
        print(f"{outcome.name}: {shown}")
    return config if outcome == TestResult.FAIL else outcome
246294

247295
async def sweepParameters(paramIter: Iterable[IterType],
248-
toConfig: Callable[[IterType, Options], MLIROnlyConfig],
249-
options: Options, paths: Paths) -> Tuple[int, int, List[MLIROnlyConfig]]:
296+
toConfig: Callable[[IterType, Options], PerfConfig],
297+
options: Options, paths: Paths) -> Tuple[int, int, List[PerfConfig]]:
250298
failingConfigs = []
251299
passed = 0
252300
invalid = 0

mlir/utils/performance/perfRunner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def initializeDataTypesAttention():
155155
else:
156156
DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_WMMA
157157

158+
return DATA_TYPES_ATTENTION # For modules that import this function
159+
158160
def create_paths(config_file_path, mlir_build_dir_path) -> Paths:
159161
"""Creates the composite Paths structure using build dir paths"""
160162

0 commit comments

Comments
 (0)