Skip to content

Commit 43e7c78

Browse files
committed
hack tiling for unpack sve
Signed-off-by: Ege Beysel <[email protected]>
1 parent 6fc1deb commit 43e7c78

File tree

1 file changed

+49
-3
lines changed

1 file changed

+49
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
#include "mlir/Dialect/Utils/IndexingUtils.h"
2121
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2222
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2324
#include "mlir/IR/BuiltinTypeInterfaces.h"
25+
#include "mlir/IR/Value.h"
2426
#include "mlir/Interfaces/TilingInterface.h"
2527
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2628
#include "llvm/Support/Debug.h"
29+
#include "llvm/Support/LogicalResult.h"
30+
#include <cstdint>
2731
#include <optional>
2832

2933
#define DEBUG_TYPE "linalg-tiling-interface-impl"
@@ -1036,6 +1040,25 @@ struct UnpackTileDimInfo {
10361040
OpFoldResult destExpandedSize;
10371041
};
10381042

1043+
FailureOr<int64_t> getStaticPartOfScalableTileSize(Operation *op) {
1044+
auto mulIOp = dyn_cast<arith::MulIOp>(op);
1045+
if (!mulIOp)
1046+
return failure();
1047+
1048+
auto lhs = mulIOp.getLhs().getDefiningOp();
1049+
auto rhs = mulIOp.getRhs().getDefiningOp();
1050+
1051+
auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs)
1052+
: dyn_cast<arith::ConstantOp>(rhs);
1053+
if (!cstOp)
1054+
return failure();
1055+
if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs))
1056+
return failure();
1057+
if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue()))
1058+
return integerAttr.getInt();
1059+
return failure();
1060+
}
1061+
10391062
/// Returns the needed information for tiling unpack op on `tileDim` with given
10401063
/// `tileOffset` and `tileSize`. For more details, see the comment of the
10411064
/// `getTiledImplementation`.
@@ -1072,13 +1095,36 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
10721095
presburger::BoundType::UB, tileSize,
10731096
/*stopCondition=*/nullptr, /*closedUB=*/true);
10741097
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
1075-
if (!failed(cstSize) && cstInnerSize) {
1076-
if (*cstSize % *cstInnerSize == 0)
1098+
bool assumeInnerTileSizesMatchTiles = false;
1099+
if (!cstInnerSize) {
1100+
// Warning: Hard-coded separation of distribution and vector level tiling :D
1101+
Value scalableInnerTileSize = cast<Value>(innerTileSize);
1102+
auto staticInnerTileSize =
1103+
getStaticPartOfScalableTileSize(scalableInnerTileSize.getDefiningOp());
1104+
info.isAlignedToInnerTileSize = succeeded(staticInnerTileSize);
1105+
if (auto tileSizeVal = dyn_cast<Value>(tileSize)) {
1106+
if (auto tileSizeAffineOp =
1107+
tileSizeVal.getDefiningOp<affine::AffineMinOp>()) {
1108+
if (tileSizeAffineOp->getNumOperands() >= 2) {
1109+
Value scalableTileSize = tileSizeAffineOp->getOperand(tileSizeAffineOp->getNumOperands() - 1);
1110+
auto staticTileSize =
1111+
getStaticPartOfScalableTileSize(scalableTileSize.getDefiningOp());
1112+
assumeInnerTileSizesMatchTiles =
1113+
succeeded(staticTileSize) && succeeded(staticInnerTileSize) &&
1114+
staticTileSize.value() == staticInnerTileSize.value();
1115+
}
1116+
}
1117+
}
1118+
}
1119+
if (!failed(cstSize) && (cstInnerSize || assumeInnerTileSizesMatchTiles ||
1120+
info.isAlignedToInnerTileSize)) {
1121+
if (assumeInnerTileSizesMatchTiles || *cstSize % *cstInnerSize == 0)
10771122
info.isAlignedToInnerTileSize = true;
10781123

10791124
// If the tiling size equals to the inner tiling size, the outer dims are
10801125
// always 1.
1081-
if (*cstInnerSize == *cstSize) {
1126+
if (assumeInnerTileSizesMatchTiles ||
1127+
(cstInnerSize && *cstInnerSize == *cstSize)) {
10821128
auto lhs = AV(dim0).bind(tileOffset);
10831129
auto rhs = AV(dim1).bind(innerTileSize);
10841130
info.sourceOffset = ab.floor(lhs, rhs);

0 commit comments

Comments
 (0)