12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -471,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
// a Range vector.
OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
// result type, offsets set to 0 and strides set to 1.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$sizes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -930,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
// a Range vector and inferred result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
// to 0, strides set to 1 and inferred result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<OpFoldResult>":$sizes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
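For reference, a minimal sketch of how a rewrite pattern might invoke the two new size-only builders declared above; the helper and variable names are illustrative only and not part of this patch. Offsets default to 0 and strides to 1.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Extract a slice of `source` with the given mixed static/dynamic sizes,
// anchored at the origin with unit strides, then insert it back into `dest`
// at the origin.
static Value roundTripSlice(PatternRewriter &rewriter, Location loc,
                            RankedTensorType sliceType, Value source,
                            Value dest, ArrayRef<OpFoldResult> sizes) {
  // New ExtractSliceOp builder: result type, source and sizes only.
  Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType,
                                               source, sizes);
  // New InsertSliceOp builder: source, dest and sizes only; the result type
  // is that of `dest`.
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, slice, dest,
                                              sizes);
  return insert.getResult();
}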
31 changes: 5 additions & 26 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
"this is not supported ATM!");
}

Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();

// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
writeSizes.push_back(tileSizeOfr);
}

// TODO: Add a constructor for tensor.insert_slice that doesn't require
// strides nor offsets.
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);

// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(

LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);

// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// outer-tiled-dims being all 1), this will be
// [ outer-untiled-dims, tile-sizes ]
SmallVector<OpFoldResult> extractSliceSizes;
// The offset and strides attributes for ExtractSliceOp.
SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);

// Shape for EmptyOp that's used as the init value for TransposeOp below.
// This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value innerTile = tensor::ExtractSliceOp::create(
rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
extractSliceSizes, extractSliceStrides);
rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);

// 2. Transpose the tile to match the outer corresponding tile order.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(

// 3. Handle in-complete tiles if needed. It truncates trailing data from the
// transposed tile.
int numLoops = shapeForEmptyOp.size();
SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1393,22 +1375,19 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
}

auto partialTile =
tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
tileOffsets, tileSizes, tileStrides);
tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
transposedOp.getResult()[0], tileSizes);

// 4. Insert the result to the destination tensor.
SmallVector<OpFoldResult> writeSizes;
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
for (int i = 0, idx = 0; i < destRank; ++i) {
if (dimAndTileMapping.count(i) || destShape[i] != 1)
writeSizes.push_back(tileSizes[idx++]);
else
writeSizes.push_back(oneIdxAttr);
}
auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
unpackOp.getDest(), writeOffsets,
writeSizes, writeStrides);
unpackOp.getDest(), writeSizes);
rewriter.replaceOp(unpackOp, insert.getResult());

return success();
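As the updated unpack pattern above suggests, the new ExtractSliceOp builder can be used with an explicit (possibly rank-reduced) result type such as readType, or with a default-constructed RankedTensorType(), in which case the underlying mixed-entry builder is left to infer the result type from the source and sizes. A minimal sketch of the inferred-type form, with placeholder names:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: a null (default-constructed) RankedTensorType asks the
// builder to infer the result type from `source` and `sizes`, as the
// partial-tile extraction above does.
static Value extractWithInferredType(PatternRewriter &rewriter, Location loc,
                                     Value source,
                                     ArrayRef<OpFoldResult> sizes) {
  return tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
                                        source, sizes);
}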
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2445,6 +2445,19 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}

/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
/// result type, offsets set to 0 and strides set to 1.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
RankedTensorType resultType, Value source,
ArrayRef<OpFoldResult> sizes,
ArrayRef<NamedAttribute> attrs) {
Attribute zeroIdxAttr = b.getIndexAttr(0);
Attribute oneIdxAttr = b.getIndexAttr(1);
SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
}

/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -3889,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
// to 0, strides set to 1 and inferred result type.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ArrayRef<OpFoldResult> sizes,
ArrayRef<NamedAttribute> attrs) {
Attribute zeroIdxAttr = b.getIndexAttr(0);
Attribute oneIdxAttr = b.getIndexAttr(1);
SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
}

LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
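A possible caller-side pattern, sketched under the assumption that the existing tensor::getMixedSizes helper is available: derive the mixed static/dynamic sizes from the source tensor, then let the new InsertSliceOp builder fill in the zero offsets and unit strides.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch only: insert all of `source` into `dest` at offset 0 with unit
// strides. `tensor::getMixedSizes` is assumed to return the sizes of
// `source` as a mix of static attributes and dynamic values.
static Value insertAtOrigin(OpBuilder &builder, Location loc, Value source,
                            Value dest) {
  SmallVector<OpFoldResult> sizes =
      tensor::getMixedSizes(builder, loc, source);
  auto insert =
      tensor::InsertSliceOp::create(builder, loc, source, dest, sizes);
  return insert.getResult();
}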