@@ -1254,64 +1254,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12541254 " require the tiled outer dimensions of the result are all 1s" );
12551255 }
12561256
1257- // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
1257+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1258+ // %extracted_tile = tensor.extract_slice(%unpack_op_input)
12581259 Location loc = unpackOp.getLoc ();
12591260 Value source = unpackOp.getSource ();
12601261 DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
12611262 unpackOp.getDimAndTileMapping ();
12621263 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
12631264 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
1264- SmallVector<OpFoldResult> readOffsets (srcRank, zeroIdxAttr);
1265- SmallVector<OpFoldResult> readStrides (srcRank, oneIdxAttr);
1266- SmallVector<OpFoldResult> readSizes;
1267- SmallVector<int64_t > readShape;
1268- SmallVector<Value> dynamicDims;
1265+
1266+ // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1267+ // dims:
1268+ // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1269+ SmallVector<int64_t > readShapeForExtractSlice;
1270+ // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1271+ // outer-tiled-dims being all 1), this will be
1272+ // [ outer-untiled-dims, tile-sizes ]
1273+ SmallVector<OpFoldResult> extractSliceSizes;
1274+ // The offset and strides attributes for ExtractSliceOp.
1275+ SmallVector<OpFoldResult> extractSliceOffsets (srcRank, zeroIdxAttr);
1276+ SmallVector<OpFoldResult> extractSliceStrides (srcRank, oneIdxAttr);
1277+
1278+ // Shape for EmptyOp that's used as the init value for TransposeOp below.
1279+ // This should be:
1280+ // [ outer-untiled-dims, tile-sizes ]
1281+ // However, skip unit dims - TransposeOp (below) applies rank-reduced
1282+ // permutation.
1283+ SmallVector<OpFoldResult> shapeForEmptyOp;
1284+
12691285 for (auto i : llvm::seq<unsigned >(0 , destRank)) {
1286+ // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1287+ //
1288+ // As all outer tiled dims are 1, the corresponding
1289+ // slice size to read will also be 1. As this will be a rank-reducing "extract
1290+ // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1291+ // update:
1292+ // * the output shape for ExtractSliceOp, nor
1293+ // * the shape for EmptyOp.
12701294 if (dimAndTileMapping.count (i)) {
1271- readSizes .push_back (oneIdxAttr);
1295+ extractSliceSizes .push_back (oneIdxAttr);
12721296 continue ;
12731297 }
12741298
1299+ // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1300+ // outer-untiled-dims
12751301 if (ShapedType::isDynamic (srcShape[i])) {
1276- Value dynamicDim =
1302+ OpFoldResult dynamicDim =
12771303 rewriter.create <tensor::DimOp>(loc, source, i).getResult ();
1278- readSizes .push_back (dynamicDim);
1279- dynamicDims .push_back (dynamicDim);
1304+ extractSliceSizes .push_back (dynamicDim);
1305+ shapeForEmptyOp .push_back (dynamicDim);
12801306 } else {
1281- readSizes.push_back (rewriter.getIndexAttr (srcShape[i]));
1307+ extractSliceSizes.push_back (rewriter.getIndexAttr (srcShape[i]));
1308+ if (srcShape[i] != 1 )
1309+ shapeForEmptyOp.push_back (rewriter.getIndexAttr (srcShape[i]));
1310+ }
1311+ // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1312+ // into account rank-reducing)
1313+ if (srcShape[i] != 1 ) {
1314+ readShapeForExtractSlice.push_back (srcShape[i]);
12821315 }
1283- if (srcShape[i] != 1 )
1284- readShape.push_back (srcShape[i]);
12851316 }
1317+ // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1318+ // shape for EmptyOp.
12861319 auto mixedTiles = unpackOp.getMixedTiles ();
1287- readSizes.append (mixedTiles.begin (), mixedTiles.end ());
1320+ extractSliceSizes.append (mixedTiles.begin (), mixedTiles.end ());
1321+ shapeForEmptyOp.append (mixedTiles.begin (), mixedTiles.end ());
12881322
12891323 // Explicitly create the type for extract_slice op because the inner tile
12901324 // size could be 1. We want to represent the whole inner tile in this case.
12911325 auto tileShape = srcShape.drop_front (destRank);
12921326 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1293- readShape .append (tileShape.begin (), tileShape.end ());
1327+ readShapeForExtractSlice .append (tileShape.begin (), tileShape.end ());
12941328 Type elemType = unpackOp.getSourceType ().getElementType ();
1295- auto readType = RankedTensorType::get (readShape , elemType);
1329+ auto readType = RankedTensorType::get (readShapeForExtractSlice , elemType);
12961330 Value innerTile = rewriter.create <tensor::ExtractSliceOp>(
1297- loc, readType, unpackOp.getSource (), readOffsets, readSizes, readStrides);
1331+ loc, readType, unpackOp.getSource (), extractSliceOffsets,
1332+ extractSliceSizes, extractSliceStrides);
12981333
12991334 // 2. Transpose the tile to match the outer corresponding tile order.
13001335 SmallVector<int64_t > perm = getPackUnpackRankReducedPerm (
13011336 srcShape.take_front (destRank), innerDimsPos, unpackOp.getOuterDimsPerm ());
13021337 // Unpack is a transition out of packed space so we invert the permutation.
13031338 perm = invertPermutationVector (perm);
1304- SmallVector<int64_t > transpShape (readShape);
1305- applyPermutationToVector<int64_t >(transpShape, perm);
1339+ applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
13061340
13071341 Value empty =
1308- rewriter.create <tensor::EmptyOp>(loc, transpShape , elemType, dynamicDims );
1342+ rewriter.create <tensor::EmptyOp>(loc, shapeForEmptyOp , elemType);
13091343 auto transposedOp =
13101344 rewriter.create <linalg::TransposeOp>(loc, innerTile, empty, perm);
13111345
13121346 // 3. Handle incomplete tiles if needed. It truncates trailing data from the
13131347 // transposed tile.
1314- int numLoops = transpShape .size ();
1348+ int numLoops = shapeForEmptyOp .size ();
13151349 SmallVector<OpFoldResult> tileStrides (numLoops, oneIdxAttr);
13161350 SmallVector<OpFoldResult> tileOffsets (numLoops, zeroIdxAttr);
13171351 SmallVector<OpFoldResult> tileSizes;
0 commit comments