Skip to content

Commit 2e743e5

Browse files
[BACKPORT] Introduce new quick tune lists based on Tier1 configs and separated b… (#1938)
* Introduce new quick tune lists based on Tier1 configs and separated by architecture (#1907) Co-authored-by: Mirza Halilcevic <[email protected]>
1 parent d8e45e1 commit 2e743e5

25 files changed

+1523
-383
lines changed

mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
124124
gemmNPerThread(attr.getNPerThread()), blockSize(attr.getBlockSize()),
125125
splitKFactor(attr.getSplitKFactor()),
126126
gemmScheduleVersion(attr.getScheduleVersion()),
127-
outputSwizzle(attr.getOutputSwizzle()){};
127+
outputSwizzle(attr.getOutputSwizzle()) {};
128128

129129
int64_t getKPack() { return 1; }
130130

@@ -172,7 +172,7 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
172172
gemmScheduleVersion(attr.getScheduleVersion()),
173173
outputSwizzle(attr.getOutputSwizzle()),
174174
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
175-
gemmBThreadCopyMoreGemmKPack(false){};
175+
gemmBThreadCopyMoreGemmKPack(false) {};
176176

177177
InitParamsAccel(WmmaGemmParamsAttr attr)
178178
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
@@ -183,7 +183,7 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
183183
gemmScheduleVersion(attr.getScheduleVersion()),
184184
outputSwizzle(attr.getOutputSwizzle()),
185185
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
186-
gemmBThreadCopyMoreGemmKPack(false){};
186+
gemmBThreadCopyMoreGemmKPack(false) {};
187187

188188
int64_t getKPack() { return gemmKPack; }
189189

@@ -333,8 +333,10 @@ class PopulateParams : public BasePopulateParams<InitParamsNonAccel> {
333333

334334
// Return the vector of heuristic parameters for a given kernel type and dat
335335
// type.
336-
std::vector<InitParamsNonAccel>
337-
getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB) const;
336+
std::vector<InitParamsNonAccel> getTuningParameters(KernelType opType,
337+
Type dataTypeA,
338+
Type dataTypeB,
339+
StringRef arch) const;
338340

339341
Attribute getGemmParamsAttr(OpBuilder &b,
340342
const InitParamsNonAccel &params) const override;

0 commit comments

Comments
 (0)