Skip to content

Commit 5a601b9

Browse files
authored
Pipelining for attention (#1990)
1 parent b3011c6 commit 5a601b9

File tree

48 files changed

+1455
-743
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1455
-743
lines changed

mlir/include/mlir/Dialect/Rock/IR/AmdArchDb.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ struct AmdArchInfo {
5252
};
5353

5454
AmdArchInfo lookupArchInfo(StringRef arch);
55+
bool isDirectToLDSSupported(GemmFeatures features);
5556
} // namespace rock
5657
} // namespace mlir
5758

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
297297
+ Twine(getOutputSwizzle()) + ","
298298
+ Twine(getForceUnroll())).toVector(perfStr);
299299
}
300+
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
301+
return AttnPerfConfigAttr::get(
302+
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
303+
getKpackPerBlock(), getMPerWave(), getMnPerXdl(), getKpack(),
304+
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getForceUnroll());
305+
}
300306
}];
301307

302308
let builders = [

mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,8 @@ GemmFeatures mlir::rock::AmdArchInfo::getDefaultFeatures(Type dataType) {
413413
}
414414
return theseFeatures;
415415
}
416+
417+
bool mlir::rock::isDirectToLDSSupported(GemmFeatures features) {
418+
return bitEnumContainsAll(features, GemmFeatures::direct_to_lds_128b) ||
419+
bitEnumContainsAll(features, GemmFeatures::direct_to_lds_32b);
420+
}

mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "mlir/Dialect/Rock/IR/Rock.h"
66
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
77
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
8+
#include "mlir/Dialect/Rock/IR/RockTypes.h"
89
#include "mlir/Dialect/Rock/Passes.h"
910
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
1011
#include "mlir/Dialect/Rock/Tuning/UtilityParams.h"
@@ -19,6 +20,7 @@
1920
#include "llvm/Support/Debug.h"
2021
#include "llvm/Support/LogicalResult.h"
2122
#include "llvm/Support/raw_ostream.h"
23+
#include <optional>
2224

2325
namespace mlir {
2426
namespace rock {
@@ -51,6 +53,41 @@ struct AffixTuningParameters
5153
};
5254
} // anonymous namespace
5355

56+
static FailureOr<std::optional<int64_t>> getScheduleVersion(func::FuncOp funcOp,
57+
Operation *op) {
58+
auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic();
59+
60+
std::optional<int64_t> scheduleVersion = std::nullopt;
61+
bool hasPerfConfig = op->hasAttrOfType<StringAttr>("perf_config");
62+
if (funcOp->hasAttrOfType<rock::ScheduleVersionAttr>(
63+
scheduleVersionAttrName) &&
64+
hasPerfConfig) {
65+
return op->emitError(
66+
"kernel has both perf_config and schedule_version attribute "
67+
"set. Please modify schedule version directly inside "
68+
"perf_config and remove schedule_version\n");
69+
}
70+
if (funcOp->hasAttrOfType<rock::ScheduleVersionAttr>(
71+
scheduleVersionAttrName)) {
72+
scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
73+
funcOp->removeAttr(scheduleVersionAttrName))
74+
.getScheduleVersion();
75+
} else if (!hasPerfConfig) {
76+
// set default schedule
77+
scheduleVersion = static_cast<int64_t>(GemmLoadTileType::Default);
78+
}
79+
80+
// check scheduleVersion is valid
81+
if (scheduleVersion.has_value()) {
82+
std::optional<GemmLoadTileType> maybeLoadType =
83+
rock::symbolizeGemmLoadTileType(scheduleVersion.value());
84+
if (!maybeLoadType.has_value())
85+
return op->emitOpError("schedule version value is incorrect");
86+
}
87+
88+
return scheduleVersion;
89+
}
90+
5491
void AffixTuningParameters::runOnOperation() {
5592
func::FuncOp func = getOperation();
5693
// currently, in rocMLIR we only support one Fusion Root per function.
@@ -128,33 +165,37 @@ void AffixTuningParameters::setUtilityKernelSizes(Value arg, T utilityOp) {
128165
funcOp->setAttr("grid_size", gridSizeAttr);
129166
}
130167

168+
static LogicalResult isScheduleVersionSupported(int64_t scheduleVersion,
169+
GemmFeatures features) {
170+
std::optional<GemmLoadTileType> maybeLoadType =
171+
rock::symbolizeGemmLoadTileType(scheduleVersion);
172+
if (!maybeLoadType.has_value())
173+
return failure();
174+
175+
auto loadType = maybeLoadType.value();
176+
bool directToLDS = loadType == GemmLoadTileType::DirectToLDSDefault ||
177+
loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
178+
if (directToLDS && !isDirectToLDSSupported(features))
179+
return failure();
180+
181+
return success();
182+
}
183+
131184
void AffixTuningParameters::affixTuningParametersImpl(
132185
RockGemmWrapperInterface op) {
133186
OpBuilder b(op.getContext());
134-
auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic();
135187
auto funcParent = op->getParentOfType<func::FuncOp>();
136188
std::string perfConfig;
137-
if (funcParent->hasAttrOfType<rock::ScheduleVersionAttr>(
138-
scheduleVersionAttrName) &&
139-
op->hasAttrOfType<StringAttr>("perf_config")) {
140-
op->emitError("kernel has both perf_config and schedule_version attribute "
141-
"set. Please modify schedule version directly inside "
142-
"perf_config and remove schedule_version\n");
143-
signalPassFailure();
144-
return;
145-
}
146189
if (auto perfConfigAttr =
147190
op->template getAttrOfType<StringAttr>("perf_config")) {
148191
perfConfig = perfConfigAttr.getValue().str();
149192
}
150-
// by default rocMLIR selects GEMM Schedule V1
151-
auto scheduleVersion = 1;
152-
if (funcParent->hasAttrOfType<rock::ScheduleVersionAttr>(
153-
scheduleVersionAttrName)) {
154-
scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
155-
funcParent->removeAttr(scheduleVersionAttrName))
156-
.getScheduleVersion();
157-
}
193+
FailureOr<std::optional<int64_t>> maybeScheduleVersion =
194+
getScheduleVersion(funcParent, op);
195+
if (failed(maybeScheduleVersion))
196+
return signalPassFailure();
197+
198+
std::optional<int64_t> scheduleVersion = maybeScheduleVersion.value();
158199

159200
GemmFeatures features = rock::getFeatures(op);
160201
if (isAccel(features)) {
@@ -165,9 +206,15 @@ void AffixTuningParameters::affixTuningParametersImpl(
165206
// update schedule version to what is provided by the user if and only if
166207
// user hasn't provided perfConfig, otherwise just keep whatever is inside
167208
// perfConfig
168-
if (!op->hasAttrOfType<StringAttr>("perf_config")) {
169-
validParams.gemmScheduleVersion = scheduleVersion;
209+
if (scheduleVersion.has_value())
210+
validParams.gemmScheduleVersion = scheduleVersion.value();
211+
212+
if (failed(isScheduleVersionSupported(validParams.gemmScheduleVersion,
213+
features))) {
214+
op->emitError("schedule version not supported\n");
215+
return signalPassFailure();
170216
}
217+
171218
if (failed(status)) {
172219
// Try again if allowed.
173220
if (fallBackNoConfig) {
@@ -233,9 +280,8 @@ void AffixTuningParameters::affixTuningParametersImpl(
233280
// update schedule version to what is provided by the user if and only if
234281
// user hasn't provided perfConfig, otherwise just keep whatever was
235282
// obtained from perfConfig
236-
if (!op->hasAttrOfType<StringAttr>("perf_config")) {
237-
validParams.gemmScheduleVersion = scheduleVersion;
238-
}
283+
if (scheduleVersion.has_value())
284+
validParams.gemmScheduleVersion = scheduleVersion.value();
239285

240286
Attribute gemmParams = populateParams.getGemmParamsAttr(b, validParams);
241287
op.setGemmParamsAttr(gemmParams);
@@ -289,6 +335,13 @@ void AffixTuningParameters::affixTuningParametersImpl(
289335
"with matrix accelerator extentions");
290336
return signalPassFailure();
291337
}
338+
auto funcParent = op->getParentOfType<func::FuncOp>();
339+
FailureOr<std::optional<int64_t>> maybeScheduleVersion =
340+
getScheduleVersion(funcParent, op);
341+
if (failed(maybeScheduleVersion))
342+
return signalPassFailure();
343+
344+
std::optional<int64_t> scheduleVersion = maybeScheduleVersion.value();
292345

293346
Attribute params0 = op.getGemm0Params().value_or(nullptr);
294347
// set a default one if params is not provided
@@ -305,6 +358,17 @@ void AffixTuningParameters::affixTuningParametersImpl(
305358
op.emitError("perf config string has an incorrect format.");
306359
return signalPassFailure();
307360
}
361+
362+
if (scheduleVersion.has_value())
363+
attnPerfConfig =
364+
attnPerfConfig.withScheduleVersion(scheduleVersion.value());
365+
366+
if (failed(isScheduleVersionSupported(attnPerfConfig.getScheduleVersion(),
367+
rock::getFeatures(op)))) {
368+
op->emitError("schedule version not supported\n");
369+
return signalPassFailure();
370+
}
371+
308372
GemmFeatures features = rock::getFeatures(op);
309373
RockAccelTuningParamAttrInterface accelParams0;
310374
if (bitEnumContainsAny(features, GemmFeatures::mfma)) {

0 commit comments

Comments
 (0)