55#include " mlir/Dialect/Rock/IR/Rock.h"
66#include " mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
77#include " mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
8+ #include " mlir/Dialect/Rock/IR/RockTypes.h"
89#include " mlir/Dialect/Rock/Passes.h"
910#include " mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
1011#include " mlir/Dialect/Rock/Tuning/UtilityParams.h"
1920#include " llvm/Support/Debug.h"
2021#include " llvm/Support/LogicalResult.h"
2122#include " llvm/Support/raw_ostream.h"
23+ #include < optional>
2224
2325namespace mlir {
2426namespace rock {
@@ -51,6 +53,41 @@ struct AffixTuningParameters
5153};
5254} // anonymous namespace
5355
56+ static FailureOr<std::optional<int64_t >> getScheduleVersion (func::FuncOp funcOp,
57+ Operation *op) {
58+ auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic ();
59+
60+ std::optional<int64_t > scheduleVersion = std::nullopt ;
61+ bool hasPerfConfig = op->hasAttrOfType <StringAttr>(" perf_config" );
62+ if (funcOp->hasAttrOfType <rock::ScheduleVersionAttr>(
63+ scheduleVersionAttrName) &&
64+ hasPerfConfig) {
65+ return op->emitError (
66+ " kernel has both perf_config and schedule_version attribute "
67+ " set. Please modify schedule version directly inside "
68+ " perf_config and remove schedule_version\n " );
69+ }
70+ if (funcOp->hasAttrOfType <rock::ScheduleVersionAttr>(
71+ scheduleVersionAttrName)) {
72+ scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
73+ funcOp->removeAttr (scheduleVersionAttrName))
74+ .getScheduleVersion ();
75+ } else if (!hasPerfConfig) {
76+ // set default schedule
77+ scheduleVersion = static_cast <int64_t >(GemmLoadTileType::Default);
78+ }
79+
80+ // check scheduleVersion is valid
81+ if (scheduleVersion.has_value ()) {
82+ std::optional<GemmLoadTileType> maybeLoadType =
83+ rock::symbolizeGemmLoadTileType (scheduleVersion.value ());
84+ if (!maybeLoadType.has_value ())
85+ return op->emitOpError (" schedule version value is incorrect" );
86+ }
87+
88+ return scheduleVersion;
89+ }
90+
5491void AffixTuningParameters::runOnOperation () {
5592 func::FuncOp func = getOperation ();
5693 // currently, in rocMLIR we only support one Fusion Root per function.
@@ -128,33 +165,37 @@ void AffixTuningParameters::setUtilityKernelSizes(Value arg, T utilityOp) {
128165 funcOp->setAttr (" grid_size" , gridSizeAttr);
129166}
130167
168+ static LogicalResult isScheduleVersionSupported (int64_t scheduleVersion,
169+ GemmFeatures features) {
170+ std::optional<GemmLoadTileType> maybeLoadType =
171+ rock::symbolizeGemmLoadTileType (scheduleVersion);
172+ if (!maybeLoadType.has_value ())
173+ return failure ();
174+
175+ auto loadType = maybeLoadType.value ();
176+ bool directToLDS = loadType == GemmLoadTileType::DirectToLDSDefault ||
177+ loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
178+ if (directToLDS && !isDirectToLDSSupported (features))
179+ return failure ();
180+
181+ return success ();
182+ }
183+
131184void AffixTuningParameters::affixTuningParametersImpl (
132185 RockGemmWrapperInterface op) {
133186 OpBuilder b (op.getContext ());
134- auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic ();
135187 auto funcParent = op->getParentOfType <func::FuncOp>();
136188 std::string perfConfig;
137- if (funcParent->hasAttrOfType <rock::ScheduleVersionAttr>(
138- scheduleVersionAttrName) &&
139- op->hasAttrOfType <StringAttr>(" perf_config" )) {
140- op->emitError (" kernel has both perf_config and schedule_version attribute "
141- " set. Please modify schedule version directly inside "
142- " perf_config and remove schedule_version\n " );
143- signalPassFailure ();
144- return ;
145- }
146189 if (auto perfConfigAttr =
147190 op->template getAttrOfType <StringAttr>(" perf_config" )) {
148191 perfConfig = perfConfigAttr.getValue ().str ();
149192 }
150- // by default rocMLIR selects GEMM Schedule V1
151- auto scheduleVersion = 1 ;
152- if (funcParent->hasAttrOfType <rock::ScheduleVersionAttr>(
153- scheduleVersionAttrName)) {
154- scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
155- funcParent->removeAttr (scheduleVersionAttrName))
156- .getScheduleVersion ();
157- }
193+ FailureOr<std::optional<int64_t >> maybeScheduleVersion =
194+ getScheduleVersion (funcParent, op);
195+ if (failed (maybeScheduleVersion))
196+ return signalPassFailure ();
197+
198+ std::optional<int64_t > scheduleVersion = maybeScheduleVersion.value ();
158199
159200 GemmFeatures features = rock::getFeatures (op);
160201 if (isAccel (features)) {
@@ -165,9 +206,15 @@ void AffixTuningParameters::affixTuningParametersImpl(
165206 // update schedule version to what is provided by the user if and only if
166207 // user hasn't provided perfConfig, otherwise just keep whatever is inside
167208 // perfConfig
168- if (!op->hasAttrOfType <StringAttr>(" perf_config" )) {
169- validParams.gemmScheduleVersion = scheduleVersion;
209+ if (scheduleVersion.has_value ())
210+ validParams.gemmScheduleVersion = scheduleVersion.value ();
211+
212+ if (failed (isScheduleVersionSupported (validParams.gemmScheduleVersion ,
213+ features))) {
214+ op->emitError (" schedule version not supported\n " );
215+ return signalPassFailure ();
170216 }
217+
171218 if (failed (status)) {
172219 // Try again if allowed.
173220 if (fallBackNoConfig) {
@@ -233,9 +280,8 @@ void AffixTuningParameters::affixTuningParametersImpl(
233280 // update schedule version to what is provided by the user if and only if
234281 // user hasn't provided perfConfig, otherwise just keep whatever was
235282 // obtained from perfConfig
236- if (!op->hasAttrOfType <StringAttr>(" perf_config" )) {
237- validParams.gemmScheduleVersion = scheduleVersion;
238- }
283+ if (scheduleVersion.has_value ())
284+ validParams.gemmScheduleVersion = scheduleVersion.value ();
239285
240286 Attribute gemmParams = populateParams.getGemmParamsAttr (b, validParams);
241287 op.setGemmParamsAttr (gemmParams);
@@ -289,6 +335,13 @@ void AffixTuningParameters::affixTuningParametersImpl(
289335 " with matrix accelerator extentions" );
290336 return signalPassFailure ();
291337 }
338+ auto funcParent = op->getParentOfType <func::FuncOp>();
339+ FailureOr<std::optional<int64_t >> maybeScheduleVersion =
340+ getScheduleVersion (funcParent, op);
341+ if (failed (maybeScheduleVersion))
342+ return signalPassFailure ();
343+
344+ std::optional<int64_t > scheduleVersion = maybeScheduleVersion.value ();
292345
293346 Attribute params0 = op.getGemm0Params ().value_or (nullptr );
294347 // set a default one if params is not provided
@@ -305,6 +358,17 @@ void AffixTuningParameters::affixTuningParametersImpl(
305358 op.emitError (" perf config string has an incorrect format." );
306359 return signalPassFailure ();
307360 }
361+
362+ if (scheduleVersion.has_value ())
363+ attnPerfConfig =
364+ attnPerfConfig.withScheduleVersion (scheduleVersion.value ());
365+
366+ if (failed (isScheduleVersionSupported (attnPerfConfig.getScheduleVersion (),
367+ rock::getFeatures (op)))) {
368+ op->emitError (" schedule version not supported\n " );
369+ return signalPassFailure ();
370+ }
371+
308372 GemmFeatures features = rock::getFeatures (op);
309373 RockAccelTuningParamAttrInterface accelParams0;
310374 if (bitEnumContainsAny (features, GemmFeatures::mfma)) {
0 commit comments