@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1146
1146
Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
1147
1147
Location loc = packOp.getLoc ();
1148
1148
1149
- Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
1150
- DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
1151
- packOp.getDimAndTileMapping ();
1152
1149
int64_t srcRank = packOp.getSourceRank ();
1153
1150
int64_t destRank = packOp.getDestRank ();
1154
- int64_t numTiles = destRank - srcRank;
1151
+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1152
+ int64_t numberOfTiles = innerDimsPos.size ();
1155
1153
1156
- // 1. Extract the inner tile sizes.
1157
- // Where possible, values are replaced with constant attributes (to match the
1158
- // behaviour of `getPackOpSourceOrPaddedSource`).
1159
- SmallVector<OpFoldResult> tileSizes;
1160
- for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1161
- if (dimAndTileMapping.count (i)) {
1162
- // Rather than taking the tile size as is, extact the actual constant
1163
- // value Attribute where possible, e.g.:
1164
- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165
- auto [_, tileSize] =
1166
- getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
1167
- tileSizes.push_back (tileSize);
1168
- }
1169
- }
1154
+ // 1. Get the input that is going to be packed. If the input requires padding,
1155
+ // add a padding operation and return that as the input.
1156
+ Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
1170
1157
1171
1158
// 2. Transpose the input to match the inner tile order:
1172
1159
// %init = tensor.empty()
1173
1160
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1174
1161
// outs(%init)
1175
1162
// Assumptions made:
1176
- // 1. All outer dims are 1 - the corresponding transposition order doesn't
1163
+ // - All outer dims are 1 - the corresponding transposition order doesn't
1177
1164
// matter, but requires all dim indices to be present.
1165
+
1166
+ // 2.1 Get the permutation for linalg.transpose
1178
1167
SmallVector<int64_t > srcPermForTranspose;
1179
- ArrayRef<int64_t > innerDimPos (packOp.getInnerDimsPos ());
1180
1168
for (int64_t i = 0 ; i < srcRank; i++) {
1181
1169
// We assume the `k` dimensions of the inner dim position, where `k` is the
1182
1170
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1185
1173
// rank of the source tensor. For example if we have a source tensor with
1186
1174
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1187
1175
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1188
- if (llvm::is_contained (innerDimPos , i))
1176
+ if (llvm::is_contained (innerDimsPos , i))
1189
1177
continue ;
1190
1178
srcPermForTranspose.push_back (i);
1191
1179
}
1192
- srcPermForTranspose.append (innerDimPos.begin (), innerDimPos.end ());
1180
+ srcPermForTranspose.append (innerDimsPos.begin (), innerDimsPos.end ());
1181
+
1182
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
1183
+ SmallVector<OpFoldResult> shapeForEmptyOp (srcRank - numberOfTiles,
1184
+ oneIdxAttr);
1185
+ shapeForEmptyOp.append (packOp.getMixedTiles ());
1186
+
1187
+ // getMixedTiles() may contain Values pointing to constant ops, not the
1188
+ // constant attributes. Replace them with a true OpFoldResult.
1189
+ llvm::transform (shapeForEmptyOp, shapeForEmptyOp.begin (),
1190
+ [&](OpFoldResult ofr) {
1191
+ if (auto val = llvm::dyn_cast<Value>(ofr))
1192
+ return getAsOpFoldResult (val);
1193
+ return ofr;
1194
+ });
1193
1195
1194
1196
LDBG () << " Pack permutation: " << packOp;
1195
1197
LDBG () << " perm: " << llvm::interleaved (srcPermForTranspose);
1198
+ LDBG () << " Shape of empty tensor: " << llvm::interleaved (shapeForEmptyOp);
1196
1199
1197
- // 2.1 Create tensor.empty (init value for TransposeOp)
1198
- SmallVector<OpFoldResult> transShapeForEmptyOp (srcRank - numTiles,
1199
- oneIdxAttr);
1200
- transShapeForEmptyOp.append (tileSizes);
1201
-
1202
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1203
- srcPermForTranspose);
1204
- Value empty =
1205
- tensor::EmptyOp::create (rewriter, loc, transShapeForEmptyOp,
1206
- packOp.getSourceType ().getElementType ());
1200
+ Value empty = tensor::EmptyOp::create (
1201
+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType ().getElementType ());
1207
1202
1208
- // 2.2 Create linalg.transpose
1203
+ // 2.3 Create linalg.transpose
1209
1204
auto transposedOp = linalg::TransposeOp::create (rewriter, loc, input, empty,
1210
1205
srcPermForTranspose);
1211
1206
@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1214
1209
SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1215
1210
SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
1216
1211
// Outer dims are all 1s!
1217
- SmallVector<OpFoldResult> writeSizes (destRank - dimAndTileMapping.size (),
1218
- oneIdxAttr);
1212
+ SmallVector<OpFoldResult> writeSizes (destRank - numberOfTiles, oneIdxAttr);
1219
1213
SmallVector<int64_t > writeShape;
1220
1214
1221
1215
for (auto tileSize : packOp.getMixedTiles ()) {
0 commit comments