@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The sizes of dynamic tiles
+  SmallVector<Value> dynamicTileSizes;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the dynamic size.
+    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
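For illustration only (SSA names and shapes are invented, not taken from this patch's tests): with all outer dims equal to 1 and a dynamic first inner tile `%tile`, the helper would now build a high-padded source whose tiled dim is dynamic, roughly like the sketch below (the actual helper computes the padding amount via `tensor::createPadHighOp`; `arith.subi` stands in for that computation here).

```mlir
// Hypothetical input: pack of a tensor<5x1xf32> with inner_tiles = [%tile, 2],
// so the padded shape is [?, 2] and the high padding along dim 0 is dynamic.
%c5 = arith.constant 5 : index
%high = arith.subi %tile, %c5 : index
%padded = tensor.pad %src low[0, 0] high[%high, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_value : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```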
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
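To make the relaxed guard concrete, a hedged sketch (operand names invented): a `tensor.pack` with a single dynamic inner tile size now matches, while one with two dynamic inner tiles still triggers the failure above.

```mlir
// Matches: only one inner tile size (%tile) is dynamic.
%0 = tensor.pack %src padding_value(%pad : f32)
    inner_dims_pos = [0, 1] inner_tiles = [%tile, 2]
    into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

// Still rejected: two dynamic inner tile sizes.
%1 = tensor.pack %src padding_value(%pad : f32)
    inner_dims_pos = [0, 1] inner_tiles = [%t0, %t1]
    into %dest2 : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
```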
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
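End to end, a hedged sketch of the generalized IR for the single-dynamic-tile pack sketched earlier (names invented and carried over from the sketches above; the insert_slice step sits past the end of this diff, so this mirrors the pattern's steps rather than quoting actual output). Note how `readShapeForExtractSlice` shows up as the `tensor<?x2xf32>` slice type while `transShapeForEmpty` supplies the mixed static/dynamic sizes of `tensor.empty`.

```mlir
// Continuing the earlier sketch: %padded : tensor<?x2xf32>, tile size %tile.
%tile_slice = tensor.extract_slice %padded[0, 0] [%tile, 2] [1, 1]
    : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%tile) : tensor<?x2xf32>
%transposed = linalg.transpose ins(%tile_slice : tensor<?x2xf32>)
    outs(%empty : tensor<?x2xf32>) permutation = [0, 1]
%packed = tensor.insert_slice %transposed into
    %dest[0, 0, 0, 0] [1, 1, %tile, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```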