|
20 | 20 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
21 | 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
22 | 22 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 23 | +#include "mlir/Dialect/Vector/IR/VectorOps.h" |
23 | 24 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 25 | +#include "mlir/IR/Value.h" |
24 | 26 | #include "mlir/Interfaces/TilingInterface.h" |
25 | 27 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
26 | 28 | #include "llvm/Support/Debug.h" |
| 29 | +#include "llvm/Support/LogicalResult.h" |
| 30 | +#include <cstdint> |
27 | 31 | #include <optional> |
28 | 32 |
|
29 | 33 | #define DEBUG_TYPE "linalg-tiling-interface-impl" |
@@ -1036,6 +1040,25 @@ struct UnpackTileDimInfo { |
1036 | 1040 | OpFoldResult destExpandedSize; |
1037 | 1041 | }; |
1038 | 1042 |
|
| 1043 | +FailureOr<int64_t> getStaticPartOfScalableTileSize(Operation *op) { |
| 1044 | + auto mulIOp = dyn_cast<arith::MulIOp>(op); |
| 1045 | + if (!mulIOp) |
| 1046 | + return failure(); |
| 1047 | + |
| 1048 | + auto lhs = mulIOp.getLhs().getDefiningOp(); |
| 1049 | + auto rhs = mulIOp.getRhs().getDefiningOp(); |
| 1050 | + |
| 1051 | + auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs) |
| 1052 | + : dyn_cast<arith::ConstantOp>(rhs); |
| 1053 | + if (!cstOp) |
| 1054 | + return failure(); |
| 1055 | + if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs)) |
| 1056 | + return failure(); |
| 1057 | + if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue())) |
| 1058 | + return integerAttr.getInt(); |
| 1059 | + return failure(); |
| 1060 | +} |
| 1061 | + |
1039 | 1062 | /// Returns the needed information for tiling unpack op on `tileDim` with given |
1040 | 1063 | /// `tileOffset` and `tileSize`. For more details, see the comment of the |
1041 | 1064 | /// `getTiledImplementation`. |
@@ -1072,13 +1095,36 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, |
1072 | 1095 | presburger::BoundType::UB, tileSize, |
1073 | 1096 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
1074 | 1097 | std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize); |
1075 | | - if (!failed(cstSize) && cstInnerSize) { |
1076 | | - if (*cstSize % *cstInnerSize == 0) |
| 1098 | + bool assumeInnerTileSizesMatchTiles = false; |
| 1099 | + if (!cstInnerSize) { |
| 1100 | + // Warning: Hard-coded separation of distribution and vector level tiling :D |
| 1101 | + Value scalableInnerTileSize = cast<Value>(innerTileSize); |
| 1102 | + auto staticInnerTileSize = |
| 1103 | + getStaticPartOfScalableTileSize(scalableInnerTileSize.getDefiningOp()); |
| 1104 | + info.isAlignedToInnerTileSize = succeeded(staticInnerTileSize); |
| 1105 | + if (auto tileSizeVal = dyn_cast<Value>(tileSize)) { |
| 1106 | + if (auto tileSizeAffineOp = |
| 1107 | + tileSizeVal.getDefiningOp<affine::AffineMinOp>()) { |
| 1108 | + if (tileSizeAffineOp->getNumOperands() >= 2) { |
| 1109 | + Value scalableTileSize = tileSizeAffineOp->getOperand(tileSizeAffineOp->getNumOperands() - 1); |
| 1110 | + auto staticTileSize = |
| 1111 | + getStaticPartOfScalableTileSize(scalableTileSize.getDefiningOp()); |
| 1112 | + assumeInnerTileSizesMatchTiles = |
| 1113 | + succeeded(staticTileSize) && succeeded(staticInnerTileSize) && |
| 1114 | + staticTileSize.value() == staticInnerTileSize.value(); |
| 1115 | + } |
| 1116 | + } |
| 1117 | + } |
| 1118 | + } |
| 1119 | + if (!failed(cstSize) && (cstInnerSize || assumeInnerTileSizesMatchTiles || |
| 1120 | + info.isAlignedToInnerTileSize)) { |
| 1121 | + if (assumeInnerTileSizesMatchTiles || *cstSize % *cstInnerSize == 0) |
1077 | 1122 | info.isAlignedToInnerTileSize = true; |
1078 | 1123 |
|
1079 | 1124 | // If the tiling size equals to the inner tiling size, the outer dims are |
1080 | 1125 | // always 1. |
1081 | | - if (*cstInnerSize == *cstSize) { |
| 1126 | + if (assumeInnerTileSizesMatchTiles || |
| 1127 | + (cstInnerSize && *cstInnerSize == *cstSize)) { |
1082 | 1128 | auto lhs = AV(dim0).bind(tileOffset); |
1083 | 1129 | auto rhs = AV(dim1).bind(innerTileSize); |
1084 | 1130 | info.sourceOffset = ab.floor(lhs, rhs); |
|
0 commit comments