Skip to content

Commit 219a818

Browse files
authored
Greedy tuning (#2131)
1 parent c6a56fe commit 219a818

File tree

10 files changed

+849
-426
lines changed

10 files changed

+849
-426
lines changed

mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ class PopulateParamsWmma : public PopulateParamsAccel {
470470
Type dataTypeB) override;
471471
};
472472

473+
FailureOr<std::pair<RockAccelTuningParamAttrInterface,
474+
RockAccelTuningParamAttrInterface>>
475+
getAttentionTuningParams(OpBuilder &b, RockGemmGemmWrapperInterface gemmGemmOp,
476+
AttnPerfConfigAttr attnPerfConfig);
477+
473478
} // namespace rock
474479
} // namespace mlir
475480
#endif // MLIR_DIALECT_ROCK_GRIDWISE_GEMM_PARAMS_H

mlir/include/mlir/Dialect/Rock/Tuning/RockTuning.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ enum class TuningParamSetKind : uint32_t {
3232
// configurations that have been shown not to yield good performance.
3333
// (Note: this filtering is currently unimplemented).
3434
Full = 1,
35+
// Tune all possible tile sizes and try N random configurations for each tile
36+
// size. Then, greedily select the best tile size, and brute force tune the
37+
// rest of params
38+
Greedy = 2,
3539
// A tuning space consisting of all possible sets of tuning parameters,
3640
// excluding those that could not be applicable to the given problem.
37-
Exhaustive = 2,
41+
Exhaustive = 3,
3842
};
3943

4044
// Parameter container holding a parameter and serialized string
@@ -49,7 +53,20 @@ struct TuningParamSet {
4953
KernelType primaryOpType;
5054
};
5155

52-
TuningParamSet *createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind);
56+
struct TuningParamSpaceSettings {
57+
unsigned iteration = 0;
58+
StringRef winningConfig = "";
59+
};
60+
61+
// Get the number of iterations needed for a given tuning kind
62+
unsigned getNumberOfIterations(TuningParamSetKind kind);
63+
64+
// Whether the tuning kind needs to have the best of previous iteration
65+
bool needToUpdateBest(TuningParamSetKind kind);
66+
67+
// Modified function signature to support multiple iterations
68+
TuningParamSet *createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind,
69+
TuningParamSpaceSettings &settings);
5370
// Get a parameters from the set of tunable parameters.
5471
bool tuningGetParam(TuningParamSet *tuningSpace, unsigned pos,
5572
ParamEntry *paramEntry);

mlir/lib/CAPI/Dialect/Rock.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ mlirRockTuningSpaceCreate(MlirModule module, RocmlirTuningParamSetKind kind) {
4444
break;
4545
}
4646
auto mod = unwrap(module);
47-
newParams = rock::createTunableParamSpace(mod, ourKind);
47+
rock::TuningParamSpaceSettings settings;
48+
newParams = rock::createTunableParamSpace(mod, ourKind, settings);
4849
return wrap(newParams);
4950
}
5051

mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp

Lines changed: 9 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -288,37 +288,6 @@ void AffixTuningParameters::affixTuningParametersImpl(
288288
}
289289
}
290290

291-
static RockAccelTuningParamAttrInterface
292-
deriveGemm1TuningParams(OpBuilder &builder, RockGemmGemmWrapperInterface op,
293-
AttnPerfConfigAttr attnPerfConfig) {
294-
auto gemm0TuningParams =
295-
cast<RockAccelTuningParamAttrInterface>(op.getGemm0Params().value());
296-
int64_t gemm1KPack = gemm0TuningParams.getKpack();
297-
if (auto gemm0XdlDerivedParams =
298-
dyn_cast<MfmaGemmParamsAttr>(op.getGemm0Params().value())) {
299-
return MfmaGemmParamsAttr::get(
300-
builder.getContext(), gemm0TuningParams.getMPerBlock() / gemm1KPack,
301-
attnPerfConfig.getMPerBlockG1(), gemm0XdlDerivedParams.getNPerBlock(),
302-
gemm0TuningParams.getKpack(),
303-
gemm0TuningParams.getMPerWave() * (attnPerfConfig.getMPerBlockG1() /
304-
gemm0TuningParams.getMPerBlock()),
305-
gemm0XdlDerivedParams.getNPerWave(),
306-
gemm0XdlDerivedParams.getMnPerXdl(), attnPerfConfig.getSplitKFactor(),
307-
gemm0XdlDerivedParams.getScheduleVersion(),
308-
gemm0XdlDerivedParams.getOutputSwizzle(),
309-
gemm0XdlDerivedParams.getForceUnroll());
310-
}
311-
return WmmaGemmParamsAttr::get(
312-
builder.getContext(), gemm0TuningParams.getMPerBlock() / gemm1KPack,
313-
attnPerfConfig.getMPerBlockG1(), attnPerfConfig.getNPerBlockG0(),
314-
gemm0TuningParams.getKpack(),
315-
gemm0TuningParams.getMPerWave() *
316-
(attnPerfConfig.getMPerBlockG1() / gemm0TuningParams.getMPerBlock()),
317-
gemm0TuningParams.getNPerWave(), gemm0TuningParams.getMnPerXdl(),
318-
attnPerfConfig.getSplitKFactor(), gemm0TuningParams.getScheduleVersion(),
319-
gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getForceUnroll());
320-
}
321-
322291
void AffixTuningParameters::affixTuningParametersImpl(
323292
RockGemmGemmWrapperInterface op) {
324293
OpBuilder builder(op.getContext());
@@ -364,55 +333,22 @@ void AffixTuningParameters::affixTuningParametersImpl(
364333
return signalPassFailure();
365334
}
366335

367-
GemmFeatures features = rock::getFeatures(op);
368-
RockAccelTuningParamAttrInterface accelParams0;
369-
if (bitEnumContainsAny(features, GemmFeatures::mfma)) {
370-
accelParams0 = MfmaGemmParamsAttr::get(
371-
builder.getContext(), attnPerfConfig.getKpackPerBlock(),
372-
attnPerfConfig.getMPerBlockG0(), attnPerfConfig.getNPerBlockG0(),
373-
attnPerfConfig.getKpack(), attnPerfConfig.getMPerWave(),
374-
attnPerfConfig.getNPerWave(), attnPerfConfig.getMnPerXdl(), 1,
375-
attnPerfConfig.getScheduleVersion(), attnPerfConfig.getOutputSwizzle(),
376-
attnPerfConfig.getForceUnroll());
377-
} else {
378-
accelParams0 = WmmaGemmParamsAttr::get(
379-
builder.getContext(), attnPerfConfig.getKpackPerBlock(),
380-
attnPerfConfig.getMPerBlockG0(), attnPerfConfig.getNPerBlockG0(),
381-
attnPerfConfig.getKpack(), attnPerfConfig.getMPerWave(),
382-
attnPerfConfig.getNPerWave(), attnPerfConfig.getMnPerXdl(), 1,
383-
attnPerfConfig.getScheduleVersion(), attnPerfConfig.getOutputSwizzle(),
384-
attnPerfConfig.getForceUnroll());
385-
}
386-
op.setGemm0ParamsAttr(accelParams0);
387-
if (attnPerfConfig.getMPerBlockG0() > attnPerfConfig.getMPerBlockG1()) {
388-
op.emitError(
389-
"The MPerBlockG0 should be larger or equal to getMPerBlockG1.");
336+
auto accelParams = getAttentionTuningParams(builder, op, attnPerfConfig);
337+
if (failed(accelParams)) {
338+
op.emitError("The provided perf config is not valid");
390339
return signalPassFailure();
391340
}
392-
RockAccelTuningParamAttrInterface accelParams1 =
393-
deriveGemm1TuningParams(builder, op, attnPerfConfig);
341+
RockAccelTuningParamAttrInterface accelParams0, accelParams1;
342+
accelParams0 = accelParams->first;
343+
accelParams1 = accelParams->second;
344+
LLVM_DEBUG(llvm::dbgs() << "accelParams0=" << accelParams0 << "\n");
345+
LLVM_DEBUG(llvm::dbgs() << "accelParams1=" << accelParams1 << "\n");
346+
op.setGemm0ParamsAttr(accelParams0);
394347
op.setGemm1ParamsAttr(accelParams1);
395348
int64_t waveSize = rock::lookupArchInfo(rock::getArchValue(op)).waveSize;
396349
int64_t blockSize = waveSize * accelParams0.getNPerBlock() *
397350
accelParams0.getMPerBlock() /
398351
(accelParams0.getMPerWave() * accelParams0.getNPerWave());
399-
auto populateParamsAccelPtr = PopulateParamsAccel::select(features);
400-
LLVM_DEBUG(llvm::dbgs() << "accelParams0=" << accelParams0 << "\n");
401-
LLVM_DEBUG(llvm::dbgs() << "accelParams1=" << accelParams1 << "\n");
402-
LogicalResult isValidBlockwiseGemm0 =
403-
populateParamsAccelPtr->isValidBlockwiseGemm(
404-
accelParams0, cast<MemRefType>(op.getAType()).getElementType(),
405-
cast<MemRefType>(op.getBType()).getElementType(),
406-
rock::getArchValue(op));
407-
LogicalResult isValidBlockwiseGemm1 =
408-
populateParamsAccelPtr->isValidBlockwiseGemm(
409-
accelParams1, cast<MemRefType>(op.getCType()).getElementType(),
410-
cast<MemRefType>(op.getCType()).getElementType(),
411-
rock::getArchValue(op));
412-
if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) {
413-
op.emitError("The provided perf config is not valid");
414-
return signalPassFailure();
415-
}
416352

417353
IntegerAttr blockSizeAttr = builder.getI32IntegerAttr(blockSize);
418354
func::FuncOp funcOp = getOperation();

mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,82 @@ Attribute PopulateParamsWmma::getGemmParamsAttr(
729729
validParams.gemmScheduleVersion, validParams.outputSwizzle,
730730
validParams.gemmAThreadCopyMoreGemmK);
731731
}
732+
733+
static RockAccelTuningParamAttrInterface
734+
deriveGemm1TuningParams(OpBuilder &b,
735+
RockAccelTuningParamAttrInterface gemm0TuningParams,
736+
AttnPerfConfigAttr attnPerfConfig) {
737+
int64_t gemm1KPack = gemm0TuningParams.getKpack();
738+
if (auto gemm0XdlDerivedParams =
739+
dyn_cast<MfmaGemmParamsAttr>(gemm0TuningParams)) {
740+
return MfmaGemmParamsAttr::get(
741+
b.getContext(), gemm0TuningParams.getMPerBlock() / gemm1KPack,
742+
attnPerfConfig.getMPerBlockG1(), gemm0XdlDerivedParams.getNPerBlock(),
743+
gemm0TuningParams.getKpack(),
744+
gemm0TuningParams.getMPerWave() * (attnPerfConfig.getMPerBlockG1() /
745+
gemm0TuningParams.getMPerBlock()),
746+
gemm0XdlDerivedParams.getNPerWave(),
747+
gemm0XdlDerivedParams.getMnPerXdl(), attnPerfConfig.getSplitKFactor(),
748+
gemm0XdlDerivedParams.getScheduleVersion(),
749+
gemm0XdlDerivedParams.getOutputSwizzle(),
750+
gemm0XdlDerivedParams.getForceUnroll());
751+
}
752+
return WmmaGemmParamsAttr::get(
753+
b.getContext(), gemm0TuningParams.getMPerBlock() / gemm1KPack,
754+
attnPerfConfig.getMPerBlockG1(), attnPerfConfig.getNPerBlockG0(),
755+
gemm0TuningParams.getKpack(),
756+
gemm0TuningParams.getMPerWave() *
757+
(attnPerfConfig.getMPerBlockG1() / gemm0TuningParams.getMPerBlock()),
758+
gemm0TuningParams.getNPerWave(), gemm0TuningParams.getMnPerXdl(),
759+
attnPerfConfig.getSplitKFactor(), gemm0TuningParams.getScheduleVersion(),
760+
gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getForceUnroll());
761+
}
762+
763+
FailureOr<std::pair<RockAccelTuningParamAttrInterface,
764+
RockAccelTuningParamAttrInterface>>
765+
mlir::rock::getAttentionTuningParams(OpBuilder &b,
766+
RockGemmGemmWrapperInterface op,
767+
AttnPerfConfigAttr attnPerfConfig) {
768+
GemmFeatures features = rock::getFeatures(op);
769+
RockAccelTuningParamAttrInterface accelParams0;
770+
if (bitEnumContainsAny(features, GemmFeatures::mfma)) {
771+
accelParams0 = MfmaGemmParamsAttr::get(
772+
b.getContext(), attnPerfConfig.getKpackPerBlock(),
773+
attnPerfConfig.getMPerBlockG0(), attnPerfConfig.getNPerBlockG0(),
774+
attnPerfConfig.getKpack(), attnPerfConfig.getMPerWave(),
775+
attnPerfConfig.getNPerWave(), attnPerfConfig.getMnPerXdl(), 1,
776+
attnPerfConfig.getScheduleVersion(), attnPerfConfig.getOutputSwizzle(),
777+
attnPerfConfig.getForceUnroll());
778+
} else {
779+
accelParams0 = WmmaGemmParamsAttr::get(
780+
b.getContext(), attnPerfConfig.getKpackPerBlock(),
781+
attnPerfConfig.getMPerBlockG0(), attnPerfConfig.getNPerBlockG0(),
782+
attnPerfConfig.getKpack(), attnPerfConfig.getMPerWave(),
783+
attnPerfConfig.getNPerWave(), attnPerfConfig.getMnPerXdl(), 1,
784+
attnPerfConfig.getScheduleVersion(), attnPerfConfig.getOutputSwizzle(),
785+
attnPerfConfig.getForceUnroll());
786+
}
787+
if (attnPerfConfig.getMPerBlockG1() % attnPerfConfig.getMPerBlockG0() != 0) {
788+
return failure();
789+
}
790+
if (attnPerfConfig.getMPerBlockG0() % attnPerfConfig.getKpack() != 0) {
791+
return failure();
792+
}
793+
RockAccelTuningParamAttrInterface accelParams1 =
794+
deriveGemm1TuningParams(b, accelParams0, attnPerfConfig);
795+
auto populateParamsAccelPtr = PopulateParamsAccel::select(features);
796+
LogicalResult isValidBlockwiseGemm0 =
797+
populateParamsAccelPtr->isValidBlockwiseGemm(
798+
accelParams0, cast<MemRefType>(op.getAType()).getElementType(),
799+
cast<MemRefType>(op.getBType()).getElementType(),
800+
rock::getArchValue(op));
801+
LogicalResult isValidBlockwiseGemm1 =
802+
populateParamsAccelPtr->isValidBlockwiseGemm(
803+
accelParams1, cast<MemRefType>(op.getCType()).getElementType(),
804+
cast<MemRefType>(op.getCType()).getElementType(),
805+
rock::getArchValue(op));
806+
if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) {
807+
return failure();
808+
}
809+
return std::make_pair(accelParams0, accelParams1);
810+
}

0 commit comments

Comments
 (0)