@@ -1261,21 +1261,30 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12611261 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
12621262 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
12631263
1264- // The sizes, affset and strides attributes for ExtractSliceOp.
1264+ // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1265+ // dims:
1266+ // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1267+ SmallVector<int64_t > readShapeForExtractSlice;
1268+ // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1269+ // outer-tiled-dims being all 1), this will be
1270+ // [ outer-untiled-dims, tile-sizes ]
12651271 SmallVector<OpFoldResult> extractSliceSizes;
1272+ // The offset and strides attributes for ExtractSliceOp.
12661273 SmallVector<OpFoldResult> extractSliceOffsets (srcRank, zeroIdxAttr);
12671274 SmallVector<OpFoldResult> extractSliceStrides (srcRank, oneIdxAttr);
1268- // The shape for ExtractSliceOp (due to rank-reducing, this is likely !=
1269- // extractSliceSizes).
1270- SmallVector<int64_t > readShapeForExtractSlice;
12711275
12721276 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1273- // This should match tile size + transposition.
1277+ // This should be:
1278+ // [ outer-untiled-dims, tile-sizes ]
1279+ // However, skip unit dims - TransposeOp (below) applies rank-reduced
1280+ // permutation.
12741281 SmallVector<OpFoldResult> shapeForEmptyOp;
12751282
12761283 for (auto i : llvm::seq<unsigned >(0 , destRank)) {
1277- // Given the assumption that all outer tiled dims are 1, the corresponding
1278- // slice size to read is also 1. As this will be rank-reducing "extract
1284+ // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1285+ //
1286+ // As all outer tiled dims are 1, so the corresponding
1287+ // slice size to read will also 1. As this will be rank-reducing "extract
12791288 // slice" (i.e. the unit dims will be "collapsed"), there's no need to
12801289 // update:
12811290 // * the output shape for ExtractSliceOp, nor
@@ -1285,7 +1294,8 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12851294 continue ;
12861295 }
12871296
1288- // Compute sizes attribute for ExtractSliceOp + EmptyOp
1297+ // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1298+ // outer-untiled-dims
12891299 if (ShapedType::isDynamic (srcShape[i])) {
12901300 OpFoldResult dynamicDim =
12911301 rewriter.create <tensor::DimOp>(loc, source, i).getResult ();
@@ -1296,15 +1306,15 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12961306 if (srcShape[i] != 1 )
12971307 shapeForEmptyOp.push_back (rewriter.getIndexAttr (srcShape[i]));
12981308 }
1299- // Compute the output shape for ExtractSliceOp (take into account
1300- // rank-reducing)
1309+ // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1310+ // into account rank-reducing)
13011311 if (srcShape[i] != 1 ) {
13021312 readShapeForExtractSlice.push_back (srcShape[i]);
13031313 }
13041314 }
1315+ // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1316+ // shape for EmptyOp.
13051317 auto mixedTiles = unpackOp.getMixedTiles ();
1306- // TODO: This effectively assumes that that tile sizes match the trailing
1307- // sizes for ExtractSliceOp and EmptyOp - document this.
13081318 extractSliceSizes.append (mixedTiles.begin (), mixedTiles.end ());
13091319 shapeForEmptyOp.append (mixedTiles.begin (), mixedTiles.end ());
13101320
0 commit comments