@@ -1146,38 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11461146 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
11471147 Location loc = packOp.getLoc ();
11481148
1149- Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
1150- DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
1151- packOp.getDimAndTileMapping ();
11521149 int64_t srcRank = packOp.getSourceRank ();
11531150 int64_t destRank = packOp.getDestRank ();
1151+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1152+ int64_t numberOfTiles = innerDimsPos.size ();
11541153
1155- // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
1156- // before transposing. Where possible, values are replaced with constant
1157- // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
1158- SmallVector<OpFoldResult> transShapeForEmptyOp (srcRank, oneIdxAttr);
1159- SmallVector<OpFoldResult> tileSizes;
1160- for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1161- if (dimAndTileMapping.count (i)) {
1162- // Rather than taking the tile size as is, extact the actual constant
1163- // value Attribute where possible, e.g.:
1164- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165- auto [_, tileSize] =
1166- getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
1167- tileSizes.push_back (tileSize);
1168- transShapeForEmptyOp[i] = tileSize;
1169- }
1170- }
1154+ // 1. Get the input that is going to be packed. If the input requires padding,
1155+ // add a padding operation and return that as the input.
1156+ Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
11711157
11721158 // 2. Transpose the input to match the inner tile order:
11731159 // %init = tensor.empty()
11741160 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11751161 // outs(%init)
11761162 // Assumptions made:
1177- // 1. All outer dims are 1 - the corresponding transposition order doesn't
1163+ // - All outer dims are 1 - the corresponding transposition order doesn't
11781164 // matter, but requires all dim indices to be present.
1165+
1166+ // 2.1 Get the permutation for linalg.transpose
11791167 SmallVector<int64_t > srcPermForTranspose;
1180- ArrayRef<int64_t > innerDimPos (packOp.getInnerDimsPos ());
11811168 for (int64_t i = 0 ; i < srcRank; i++) {
11821169 // We assume the `k` dimensions of the inner dim position, where `k` is the
11831170 // rank of the inner tiling, correspond to the last `k` indices of the
@@ -1186,21 +1173,32 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11861173 // rank of the source tensor. For example if we have a source tensor with
11871174 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
11881175 // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1189- if (llvm::is_contained (innerDimPos , i))
1176+ if (llvm::is_contained (innerDimsPos , i))
11901177 continue ;
11911178 srcPermForTranspose.push_back (i);
11921179 }
1193- srcPermForTranspose.append (innerDimPos.begin (), innerDimPos.end ());
1180+ srcPermForTranspose.append (innerDimsPos.begin (), innerDimsPos.end ());
1181+
1182+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
1183+ SmallVector<OpFoldResult> shapeForEmptyOp (srcRank - numberOfTiles,
1184+ oneIdxAttr);
1185+ shapeForEmptyOp.append (packOp.getMixedTiles ());
1186+
1187+ // getMixedTiles() may contain Values pointing to constant ops, not the
1188+ // constant attributes. Replace them with a true OpFoldResult.
1189+ llvm::transform (shapeForEmptyOp, shapeForEmptyOp.begin (),
1190+ [&](OpFoldResult ofr) {
1191+ if (auto val = llvm::dyn_cast<Value>(ofr))
1192+ return getAsOpFoldResult (val);
1193+ return ofr;
1194+ });
11941195
11951196 LDBG () << " Pack permutation: " << packOp;
11961197 LDBG () << " perm: " << llvm::interleaved (srcPermForTranspose);
1198+ LDBG () << " Shape of empty tensor: " << llvm::interleaved (shapeForEmptyOp);
11971199
1198- // 2.2 Transpose the tensor.empty shapes.
1199- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1200- srcPermForTranspose);
1201- Value empty =
1202- tensor::EmptyOp::create (rewriter, loc, transShapeForEmptyOp,
1203- packOp.getSourceType ().getElementType ());
1200+ Value empty = tensor::EmptyOp::create (
1201+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType ().getElementType ());
12041202
12051203 // 2.3 Create linalg.transpose
12061204 auto transposedOp = linalg::TransposeOp::create (rewriter, loc, input, empty,
@@ -1211,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12111209 SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
12121210 SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
12131211 // Outer dims are all 1s!
1214- SmallVector<OpFoldResult> writeSizes (destRank - dimAndTileMapping.size (),
1215- oneIdxAttr);
1212+ SmallVector<OpFoldResult> writeSizes (destRank - numberOfTiles, oneIdxAttr);
12161213 SmallVector<int64_t > writeShape;
12171214
12181215 for (auto tileSize : packOp.getMixedTiles ()) {
0 commit comments