Skip to content

Commit b16004d

Browse files
authored
Introduce perfConfig V3 with param to select different schedule (#1767)
This PR adds a placeholder for selecting a different GEMM schedule as a tuning param. It also adds a param for output swizzle, which may be used later.
1 parent ec6c1cc commit b16004d

File tree

112 files changed

+782
-614
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+782
-614
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam
225225
- mPerThread: The number of values of m to process as a unit on each thread
226226
- nPerThread: The number of values of n to process as a unit on each thread
227227
- kpack: The number of values of k to pack contiguously into the shared buffer
228+
- splitKFactor: Split-k factor for the Split-k GEMM algorithm
229+
- scheduleVersion: Param to select GEMM schedule
230+
- outputSwizzle: Whether to enable/disable output swizzle or use heuristics
228231
}];
229232
let parameters = (ins
230233
"uint32_t":$blockSize,
@@ -235,18 +238,23 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam
235238
"int64_t":$mPerThread,
236239
"int64_t":$nPerThread,
237240
"int64_t":$kpack,
238-
"int64_t":$splitKFactor
241+
"int64_t":$splitKFactor,
242+
"int64_t":$scheduleVersion,
243+
"int64_t":$outputSwizzle
239244
);
240245

241246
let extraClassDeclaration = [{
242247
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
243-
("v2:" + Twine(getBlockSize()) + ","
248+
("v3:" + Twine(getBlockSize()) + ","
244249
+ Twine(getMPerBlock()) + ","
245250
+ Twine(getNPerBlock()) + ","
246251
+ Twine(getKPerBlock()) + ","
247252
+ Twine(getMPerThread()) + ","
248253
+ Twine(getNPerThread()) + ","
249-
+ Twine(getSplitKFactor())).toVector(perfStr);
254+
+ Twine(getSplitKFactor()) + ","
255+
+ Twine(getScheduleVersion()) + ","
256+
+ Twine(getOutputSwizzle()))
257+
.toVector(perfStr);
250258
}
251259
bool getForceUnroll() { return true; }
252260
}];
@@ -286,6 +294,8 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
286294
}
287295

288296
int64_t getSplitKFactor() { return 1; }
297+
int64_t getScheduleVersion() { return 1;}
298+
int64_t getOutputSwizzle() { return 2; }
289299
}];
290300

291301
let builders = [
@@ -310,18 +320,22 @@ def Rock_XdlopsGemmParamsAttr : Rock_Attr<"XdlopsGemmParams", [RockTuningParamAt
310320
"int64_t":$mPerWave,
311321
"int64_t":$mnPerXdl,
312322
"int64_t":$splitKFactor,
323+
"int64_t":$scheduleVersion,
324+
"int64_t":$outputSwizzle,
313325
"bool":$forceUnroll
314326
);
315327

316328
let extraClassDeclaration = [{
317329
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
318-
("v2:" + Twine(getMPerBlock()) + ","
330+
("v3:" + Twine(getMPerBlock()) + ","
319331
+ Twine(getNPerBlock()) + ","
320332
+ Twine(getKpackPerBlock()) + ","
321333
+ Twine(getMPerWave()) + ","
322334
+ Twine(getMnPerXdl()) + ","
323335
+ Twine(getKpack()) + ","
324336
+ Twine(getSplitKFactor()) + ","
337+
+ Twine(getScheduleVersion()) + ","
338+
+ Twine(getOutputSwizzle()) + ","
325339
+ Twine(getForceUnroll()) + ","
326340
+ "1") /* *ThreadCopyMore* */
327341
.toVector(perfStr);
@@ -347,18 +361,22 @@ def Rock_XdlopsGemmDerivedParamsAttr : Rock_Attr<"XdlopsGemmDerivedParams", [Roc
347361
"int64_t":$nPerWave,
348362
"int64_t":$mnPerXdl,
349363
"int64_t":$splitKFactor,
364+
"int64_t":$scheduleVersion,
365+
"int64_t":$outputSwizzle,
350366
"bool":$forceUnroll
351367
);
352368

353369
let extraClassDeclaration = [{
354370
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
355-
("v2:" + Twine(getMPerBlock()) + ","
371+
("v3:" + Twine(getMPerBlock()) + ","
356372
+ Twine(getNPerBlock()) + ","
357373
+ Twine(getKpackPerBlock()) + ","
358374
+ Twine(getMPerWave()) + ","
359375
+ Twine(getNPerWave()) + ","
360376
+ Twine(getKpack()) + ","
361377
+ Twine(getSplitKFactor()) + ","
378+
+ Twine(getScheduleVersion()) + ","
379+
+ Twine(getOutputSwizzle()) + ","
362380
+ Twine(getForceUnroll()) + ","
363381
+ "1") /* *ThreadCopyMore* */
364382
.toVector(perfStr);
@@ -386,6 +404,8 @@ def Rock_XdlopsGemmDerivedParamsAttr : Rock_Attr<"XdlopsGemmDerivedParams", [Roc
386404
nPerWave,
387405
mnPerXdl,
388406
params.getSplitKFactor(),
407+
params.getScheduleVersion(),
408+
params.getOutputSwizzle(),
389409
params.getForceUnroll()
390410
);
391411
}]>
@@ -409,18 +429,22 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
409429
"int64_t":$mPerWave,
410430
"int64_t":$nPerWave,
411431
"int64_t":$splitKFactor,
432+
"int64_t":$scheduleVersion,
433+
"int64_t":$outputSwizzle,
412434
"bool":$forceUnroll
413435
);
414436

415437
let extraClassDeclaration = [{
416438
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
417-
("v2:" + Twine(getMPerBlock()) + ","
439+
("v3:" + Twine(getMPerBlock()) + ","
418440
+ Twine(getNPerBlock()) + ","
419441
+ Twine(getKpackPerBlock()) + ","
420442
+ Twine(getMPerWave()) + ","
421443
+ Twine(getNPerWave()) + ","
422444
+ Twine(getKpack()) + ","
423445
+ Twine(getSplitKFactor()) + ","
446+
+ Twine(getScheduleVersion()) + ","
447+
+ Twine(getOutputSwizzle()) + ","
424448
+ Twine(getForceUnroll()) + ","
425449
+ "1") /* *ThreadCopyMore* */
426450
.toVector(perfStr);

mlir/include/mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,29 @@ def RockTuningParamAttrInterface : AttrInterface<"RockTuningParamAttrInterface">
6969
/*methodBody=*/"",
7070
/*defaultImplementation=*/""
7171
>,
72-
72+
InterfaceMethod<
73+
/*desc=*/[{
74+
Returns value 0 or 1 to decide if rocMLIR should enable output swizzle or not.
75+
Return value 2 means rocMLIR should use heuristics to make this decision
76+
}],
77+
/*retType=*/"int64_t",
78+
/*methodName=*/"getOutputSwizzle",
79+
/*args=*/(ins),
80+
/*methodBody=*/"",
81+
/*defaultImplementation=*/""
82+
>,
83+
InterfaceMethod<
84+
/*desc=*/[{
85+
Returns version of the schedule used by the underlying GEMM algorithm.
86+
Return value of 1 means it should use 3 stage schedule with II=2.
87+
Return value of 2 means it should use 4 stage schedule with II=1.
88+
}],
89+
/*retType=*/"int64_t",
90+
/*methodName=*/"getScheduleVersion",
91+
/*args=*/(ins),
92+
/*methodBody=*/"",
93+
/*defaultImplementation=*/""
94+
>,
7395
// TODO: more methods here as needed
7496
];
7597

mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,30 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
101101
int64_t gemmNPerThread;
102102
uint32_t blockSize;
103103
int64_t splitKFactor;
104+
int64_t gemmScheduleVersion;
105+
int64_t outputSwizzle;
104106

105107
constexpr InitParamsNonAccel(uint32_t bSize, int64_t mPerBlock,
106108
int64_t nPerBlock, int64_t kPerBlock,
107109
int64_t mPerThread, int64_t nPerThread,
108-
int64_t splitKFactor)
110+
int64_t splitKFactor, int64_t scheduleVersion,
111+
int64_t outputSwizzle)
109112
: InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerThread(mPerThread),
110113
gemmNPerThread(nPerThread), blockSize(bSize),
111-
splitKFactor(splitKFactor) {}
114+
splitKFactor(splitKFactor), gemmScheduleVersion(scheduleVersion),
115+
outputSwizzle(outputSwizzle) {}
112116

113117
constexpr InitParamsNonAccel()
114-
: InitParamsNonAccel(0U, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL) {}
118+
: InitParamsNonAccel(0U, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL) {}
115119

116120
InitParamsNonAccel(GeneralGemmParamsAttr attr)
117121
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
118122
attr.getKPerBlock()},
119123
gemmMPerThread(attr.getMPerThread()),
120124
gemmNPerThread(attr.getNPerThread()), blockSize(attr.getBlockSize()),
121-
splitKFactor(attr.getSplitKFactor()){};
125+
splitKFactor(attr.getSplitKFactor()),
126+
gemmScheduleVersion(attr.getScheduleVersion()),
127+
outputSwizzle(attr.getOutputSwizzle()){};
122128

123129
int64_t getKPack() { return 1; }
124130

@@ -130,32 +136,41 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
130136
f(self.gemmKPerBlock);
131137
f(self.gemmMPerThread);
132138
f(self.gemmNPerThread);
133-
if (self.version != Version::V1)
139+
if (self.version >= Version::V2)
134140
f(self.splitKFactor);
141+
if (self.version >= Version::V3) {
142+
f(self.gemmScheduleVersion);
143+
f(self.outputSwizzle);
144+
}
135145
}
136146
};
137147

138148
struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
139149
constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
140150
int64_t kPerBlock, int64_t mPerWave,
141151
int64_t nPerWaveOrMnPerXdl, int64_t kPack,
142-
int64_t splitKFactor, bool aThreadCopyMoreGemmK,
152+
int64_t splitKFactor, int64_t scheduleVersion,
153+
int64_t outputSwizzle, bool aThreadCopyMoreGemmK,
143154
bool bThreadCopyMoreGemmKPack)
144155
: InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave),
145156
gemmNPerWaveOrMnPerXdl(nPerWaveOrMnPerXdl), gemmKPack(kPack),
146-
splitKFactor(splitKFactor),
157+
splitKFactor(splitKFactor), gemmScheduleVersion(scheduleVersion),
158+
outputSwizzle(outputSwizzle),
147159
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
148160
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}
149161

150162
constexpr InitParamsAccel()
151-
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, false, false) {}
163+
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false,
164+
false) {}
152165

153166
InitParamsAccel(XdlopsGemmParamsAttr attr)
154167
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
155168
attr.getKpackPerBlock()},
156169
gemmMPerWave(attr.getMPerWave()),
157170
gemmNPerWaveOrMnPerXdl(attr.getMnPerXdl()), gemmKPack(attr.getKpack()),
158171
splitKFactor(attr.getSplitKFactor()),
172+
gemmScheduleVersion(attr.getScheduleVersion()),
173+
outputSwizzle(attr.getOutputSwizzle()),
159174
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
160175
gemmBThreadCopyMoreGemmKPack(false){};
161176

@@ -165,6 +180,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
165180
gemmMPerWave(attr.getMPerWave()),
166181
gemmNPerWaveOrMnPerXdl(attr.getNPerWave()), gemmKPack(attr.getKpack()),
167182
splitKFactor(attr.getSplitKFactor()),
183+
gemmScheduleVersion(attr.getScheduleVersion()),
184+
outputSwizzle(attr.getOutputSwizzle()),
168185
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
169186
gemmBThreadCopyMoreGemmKPack(false){};
170187

@@ -174,6 +191,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
174191
int64_t gemmNPerWaveOrMnPerXdl;
175192
int64_t gemmKPack;
176193
int64_t splitKFactor;
194+
int64_t gemmScheduleVersion;
195+
int64_t outputSwizzle;
177196
bool gemmAThreadCopyMoreGemmK;
178197
bool gemmBThreadCopyMoreGemmKPack;
179198

@@ -185,9 +204,13 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
185204
f(self.gemmMPerWave);
186205
f(self.gemmNPerWaveOrMnPerXdl);
187206
f(self.gemmKPack);
188-
if (self.version != Version::V1) {
207+
if (self.version >= Version::V2) {
189208
f(self.splitKFactor);
190209
}
210+
if (self.version >= Version::V3) {
211+
f(self.gemmScheduleVersion);
212+
f(self.outputSwizzle);
213+
}
191214
f(self.gemmAThreadCopyMoreGemmK);
192215
f(self.gemmBThreadCopyMoreGemmKPack);
193216
}
@@ -206,14 +229,15 @@ class BasePopulateParams {
206229
private:
207230
struct InitParamData {
208231
InitParamType paramSet;
209-
size_t original_pos;
210-
int64_t padding_amount;
232+
size_t originalPos;
233+
int64_t paddingAmount;
211234

212235
bool operator<(const InitParamData &rhs) const {
213-
if (this->padding_amount < rhs.padding_amount) {
236+
if (this->paddingAmount < rhs.paddingAmount) {
214237
return true;
215-
} else if (this->padding_amount == rhs.padding_amount) {
216-
return (this->original_pos < rhs.original_pos);
238+
}
239+
if (this->paddingAmount == rhs.paddingAmount) {
240+
return (this->originalPos < rhs.originalPos);
217241
}
218242
return false;
219243
}
@@ -227,10 +251,10 @@ class BasePopulateParams {
227251
InitParamType paramSet = initParams[pos];
228252
InitParamData paramData;
229253
paramData.paramSet = paramSet;
230-
paramData.original_pos = pos;
231-
paramData.padding_amount = calculatePaddingAmount(paramSet, gemmSize);
232-
assert(paramData.original_pos >= 0);
233-
assert(paramData.padding_amount >= 0);
254+
paramData.originalPos = pos;
255+
paramData.paddingAmount = calculatePaddingAmount(paramSet, gemmSize);
256+
assert(paramData.originalPos >= 0);
257+
assert(paramData.paddingAmount >= 0);
234258
res.push_back(paramData);
235259
}
236260
return res;

0 commit comments

Comments
 (0)