Skip to content

Commit f2234f6

Browse files
committed
hack tiling for pack sve
Signed-off-by: Ege Beysel <[email protected]>
1 parent 43e7c78 commit f2234f6

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,25 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
724724
applyPermutationToVector<OpFoldResult>(sizes, permutation);
725725
}
726726

727+
static FailureOr<int64_t> getStaticPartOfScalableTileSize(Operation *op) {
728+
auto mulIOp = dyn_cast<arith::MulIOp>(op);
729+
if (!mulIOp)
730+
return failure();
731+
732+
auto lhs = mulIOp.getLhs().getDefiningOp();
733+
auto rhs = mulIOp.getRhs().getDefiningOp();
734+
735+
auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs)
736+
: dyn_cast<arith::ConstantOp>(rhs);
737+
if (!cstOp)
738+
return failure();
739+
if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs))
740+
return failure();
741+
if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue()))
742+
return integerAttr.getInt();
743+
return failure();
744+
}
745+
727746
struct PackOpTiling
728747
: public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
729748

@@ -950,7 +969,30 @@ struct PackOpTiling
950969
// another word, we can only support tiling with consumer if the tile
951970
// size for the producer is a multiple of the inner tile size for the
952971
// packed dimensions at this moment.
953-
if ((failed(cstTileSize) || !cstInnerSize ||
972+
bool assumeInnerTileSizesMatchTiles = false;
973+
bool isAlignedToInnerTileSize = false;
974+
if (!cstInnerSize) {
975+
// Warning: Hard-coded separation of distribution and vector level tiling :D
976+
Value scalableInnerTileSize = cast<Value>(dimAndTileMapping[dim]);
977+
auto staticInnerTileSize =
978+
getStaticPartOfScalableTileSize(scalableInnerTileSize.getDefiningOp());
979+
isAlignedToInnerTileSize = succeeded(staticInnerTileSize);
980+
if (auto tileSizeVal = dyn_cast<Value>(sizes[dim])) {
981+
if (auto tileSizeAffineOp =
982+
tileSizeVal.getDefiningOp<affine::AffineMinOp>()) {
983+
if (tileSizeAffineOp->getNumOperands() >= 2) {
984+
Value scalableTileSize = tileSizeAffineOp->getOperand(tileSizeAffineOp->getNumOperands() - 1);
985+
auto staticTileSize =
986+
getStaticPartOfScalableTileSize(scalableTileSize.getDefiningOp());
987+
assumeInnerTileSizesMatchTiles =
988+
succeeded(staticTileSize) && succeeded(staticInnerTileSize) &&
989+
staticTileSize.value() == staticInnerTileSize.value();
990+
}
991+
}
992+
}
993+
}
994+
// TODO: here comes the alignment flag
995+
if (!isAlignedToInnerTileSize && (failed(cstTileSize) || !cstInnerSize ||
954996
*cstTileSize % *cstInnerSize != 0))
955997
return failure();
956998

@@ -963,6 +1005,7 @@ struct PackOpTiling
9631005
auto avSize = AV(dim0).bind(sizes[dim]);
9641006
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
9651007
outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
1008+
// TODO: here comes the equality one and 1 outer dim :)
9661009
outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
9671010
} else {
9681011
outerDimOffsets.push_back(offsets[dim]);
@@ -1040,25 +1083,6 @@ struct UnpackTileDimInfo {
10401083
OpFoldResult destExpandedSize;
10411084
};
10421085

1043-
FailureOr<int64_t> getStaticPartOfScalableTileSize(Operation *op) {
1044-
auto mulIOp = dyn_cast<arith::MulIOp>(op);
1045-
if (!mulIOp)
1046-
return failure();
1047-
1048-
auto lhs = mulIOp.getLhs().getDefiningOp();
1049-
auto rhs = mulIOp.getRhs().getDefiningOp();
1050-
1051-
auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs)
1052-
: dyn_cast<arith::ConstantOp>(rhs);
1053-
if (!cstOp)
1054-
return failure();
1055-
if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs))
1056-
return failure();
1057-
if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue()))
1058-
return integerAttr.getInt();
1059-
return failure();
1060-
}
1061-
10621086
/// Returns the needed information for tiling unpack op on `tileDim` with given
10631087
/// `tileOffset` and `tileSize`. For more details, see the comment of the
10641088
/// `getTiledImplementation`.

0 commit comments

Comments
 (0)