From dd275ad51edbbdcc2a78c7114560e2d9ac42b3ca Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 25 Nov 2025 17:30:54 +0000 Subject: [PATCH] [mlir][tensor] Add new builders for insert_slice/extract_slice Ops (nfc) Adds new builders for `tensor.insert_slice` and `tensor.extract_slice` Ops for which the _offsets_ and the _strides_ are all 0s and 1s, respecitvely. This allows us to write: ```cpp tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes); ``` instead of: ```cpp Attribute oneIdxAttr = rewriter.getIndexAttr(1); SmallVector writeStrides(destRank, oneIdxAttr); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); SmallVector writeOffsets(destRank, zeroIdxAttr); tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets, writeSizes, writeStrides); ``` --- .../mlir/Dialect/Tensor/IR/TensorOps.td | 12 ++++++- .../Dialect/Linalg/Transforms/Transforms.cpp | 31 +++---------------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 25 +++++++++++++++ 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index ac40d5e454281..35d2b6007c628 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -471,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", // a Range vector. OpBuilder<(ins "Value":$source, "ArrayRef":$ranges, CArg<"ArrayRef", "{}">:$attrs)>, + // Build an ExtractSliceOp with mixed static and dynamic sizes, inferred + // result type, offsets set to 0 and strides set to 1. + OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, + "ArrayRef":$sizes, + CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -930,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ // a Range vector and inferred result type. OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$ranges, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"ArrayRef", "{}">:$attrs)>, + // Build an InsertSliceOp with mixed static and dynamic sizes, offsets set + // to 0, strides set to 1 and inferred result type. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$sizes, + CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 027268cc20ddd..67e2b9f8d6077 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( "this is not supported ATM!"); } - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); - int64_t destRank = packOp.getDestRank(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( writeSizes.push_back(tileSizeOfr); } - // TODO: Add a constructor for tensor.insert_slice that doesn't require - // strides nor offsets. - SmallVector writeStrides(destRank, oneIdxAttr); - SmallVector writeOffsets(destRank, zeroIdxAttr); - auto insert = tensor::InsertSliceOp::create( - rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), - writeOffsets, writeSizes, writeStrides); + rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes); // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); @@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { - int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); @@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Value source = unpackOp.getSource(); DenseMap dimAndTileMapping = unpackOp.getDimAndTileMapping(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of @@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-tiled-dims being all 1), this will be // [ outer-untiled-dims, tile-sizes ] SmallVector extractSliceSizes; - // The offset and strides attributes for ExtractSliceOp. - SmallVector extractSliceOffsets(srcRank, zeroIdxAttr); - SmallVector extractSliceStrides(srcRank, oneIdxAttr); // Shape for EmptyOp that's used as the init value for TransposeOp below. // This should be: @@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); Value innerTile = tensor::ExtractSliceOp::create( - rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets, - extractSliceSizes, extractSliceStrides); + rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes); // 2. Transpose the tile to match the outer corresponding tile order. SmallVector perm = getPackUnpackRankReducedPerm( @@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. - int numLoops = shapeForEmptyOp.size(); - SmallVector tileStrides(numLoops, oneIdxAttr); - SmallVector tileOffsets(numLoops, zeroIdxAttr); SmallVector tileSizes; ArrayRef destShape = unpackOp.getDestType().getShape(); for (auto i : llvm::seq(0, destRank)) { @@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( } auto partialTile = - tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], - tileOffsets, tileSizes, tileStrides); + tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(), + transposedOp.getResult()[0], tileSizes); // 4. Insert the result to the destination tensor. SmallVector writeSizes; - SmallVector writeStrides(destRank, oneIdxAttr); - SmallVector writeOffsets(destRank, zeroIdxAttr); for (int i = 0, idx = 0; i < destRank; ++i) { if (dimAndTileMapping.count(i) || destShape[i] != 1) writeSizes.push_back(tileSizes[idx++]); @@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( writeSizes.push_back(oneIdxAttr); } auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, - unpackOp.getDest(), writeOffsets, - writeSizes, writeStrides); + unpackOp.getDest(), writeSizes); rewriter.replaceOp(unpackOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 5a58d7cbed30f..204e9bb73e12c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2445,6 +2445,19 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } } +/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred +/// result type, offsets set to 0 and strides set to 1. +void ExtractSliceOp::build(OpBuilder &b, OperationState &result, + RankedTensorType resultType, Value source, + ArrayRef sizes, + ArrayRef attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector readStrides(sizes.size(), oneIdxAttr); + SmallVector readOffsets(sizes.size(), zeroIdxAttr); + build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs); +} + /// Verifier for ExtractSliceOp. LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); @@ -3889,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set +// to 0, strides set to 1 and inferred result type. +void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef sizes, + ArrayRef attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector writeStrides(sizes.size(), oneIdxAttr); + SmallVector writeOffsets(sizes.size(), zeroIdxAttr); + build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs); +} + LogicalResult ParallelInsertSliceOp::verify() { if (!isa(getOperation()->getParentOp())) return this->emitError("expected InParallelOpInterface parent, got:")