@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11671167 " this is not supported ATM!" );
11681168 }
11691169
1170- Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
1171- Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
11721170 Location loc = packOp.getLoc ();
11731171
11741172 int64_t srcRank = packOp.getSourceRank ();
1175- int64_t destRank = packOp.getDestRank ();
11761173
11771174 // 1. Get the input that is going to be packed. If the input requires padding,
11781175 // add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12621259 writeSizes.push_back (tileSizeOfr);
12631260 }
12641261
1265- // TODO: Add a constructor for tensor.insert_slice that doesn't require
1266- // strides nor offsets.
1267- SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1268- SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
1269-
12701262 auto insert = tensor::InsertSliceOp::create (
1271- rewriter, loc, transposedOp.getResult ()[0 ], packOp.getDest (),
1272- writeOffsets, writeSizes, writeStrides);
1263+ rewriter, loc, transposedOp.getResult ()[0 ], packOp.getDest (), writeSizes);
12731264
12741265 // 4. Replace tensor.packOp with tensor.insert_slice created above
12751266 rewriter.replaceOp (packOp, insert.getResult ());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12791270
12801271LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite (
12811272 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1282- int64_t srcRank = unpackOp.getSourceRank ();
12831273 int64_t destRank = unpackOp.getDestRank ();
12841274 ArrayRef<int64_t > srcShape = unpackOp.getSourceType ().getShape ();
12851275 ArrayRef<int64_t > innerDimsPos = unpackOp.getInnerDimsPos ();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12961286 Value source = unpackOp.getSource ();
12971287 DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
12981288 unpackOp.getDimAndTileMapping ();
1299- Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
13001289 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
13011290
13021291 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13071296 // outer-tiled-dims being all 1), this will be
13081297 // [ outer-untiled-dims, tile-sizes ]
13091298 SmallVector<OpFoldResult> extractSliceSizes;
1310- // The offset and strides attributes for ExtractSliceOp.
1311- SmallVector<OpFoldResult> extractSliceOffsets (srcRank, zeroIdxAttr);
1312- SmallVector<OpFoldResult> extractSliceStrides (srcRank, oneIdxAttr);
13131299
13141300 // Shape for EmptyOp that's used as the init value for TransposeOp below.
13151301 // This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13641350 Type elemType = unpackOp.getSourceType ().getElementType ();
13651351 auto readType = RankedTensorType::get (readShapeForExtractSlice, elemType);
13661352 Value innerTile = tensor::ExtractSliceOp::create (
1367- rewriter, loc, readType, unpackOp.getSource (), extractSliceOffsets,
1368- extractSliceSizes, extractSliceStrides);
1353+ rewriter, loc, readType, unpackOp.getSource (), extractSliceSizes);
13691354
13701355 // 2. Transpose the tile to match the outer corresponding tile order.
13711356 SmallVector<int64_t > perm = getPackUnpackRankReducedPerm (
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13811366
13821367 // 3. Handle in-complete tiles if needed. It truncates trailing data from the
13831368 // transposed tile.
1384- int numLoops = shapeForEmptyOp.size ();
1385- SmallVector<OpFoldResult> tileStrides (numLoops, oneIdxAttr);
1386- SmallVector<OpFoldResult> tileOffsets (numLoops, zeroIdxAttr);
13871369 SmallVector<OpFoldResult> tileSizes;
13881370 ArrayRef<int64_t > destShape = unpackOp.getDestType ().getShape ();
13891371 for (auto i : llvm::seq<unsigned >(0 , destRank)) {
@@ -1393,22 +1375,19 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13931375 }
13941376
13951377 auto partialTile =
1396- tensor::ExtractSliceOp::create (rewriter, loc, transposedOp. getResult ()[ 0 ] ,
1397- tileOffsets , tileSizes, tileStrides );
1378+ tensor::ExtractSliceOp::create (rewriter, loc, RankedTensorType () ,
1379+ transposedOp. getResult ()[ 0 ] , tileSizes);
13981380
13991381 // 4. Insert the result to the destination tensor.
14001382 SmallVector<OpFoldResult> writeSizes;
1401- SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1402- SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
14031383 for (int i = 0 , idx = 0 ; i < destRank; ++i) {
14041384 if (dimAndTileMapping.count (i) || destShape[i] != 1 )
14051385 writeSizes.push_back (tileSizes[idx++]);
14061386 else
14071387 writeSizes.push_back (oneIdxAttr);
14081388 }
14091389 auto insert = tensor::InsertSliceOp::create (rewriter, loc, partialTile,
1410- unpackOp.getDest (), writeOffsets,
1411- writeSizes, writeStrides);
1390+ unpackOp.getDest (), writeSizes);
14121391 rewriter.replaceOp (unpackOp, insert.getResult ());
14131392
14141393 return success ();
0 commit comments