diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index a403e89a39f98..228317b4adb6c 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -818,6 +818,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } }]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 929a2a7d39649..a1f4ddc1221fc 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -20,7 +20,9 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Support/LLVM.h" // Pull in all enum type definitions and utility function declarations. @@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op, SmallVector getPrunedAttributeList(Operation *op, ArrayRef elidedAttrs); +/// Folds cast-like operations into a consuming DestinationStyleOpInterface op +/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand', +/// then the tied result type is updated as well to the type of the cast source, +/// and a new cast must be inserted on the new op's result. `createCast` is used +/// to build such required cast ops. +/// +/// ### Example +/// If the `isPreservingCast` returns true if the cast is a "generalizing" +/// `tensor.cast`, then this function would be have as follows: +/// +/// ```mlir +/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor +/// %2 = dps_op %1 ... : tensor ... +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = dps_op %0 ... : tensor<8x16xf32> ... +/// ``` +LogicalResult foldCastProducers( + RewriterBase &rewriter, DestinationStyleOpInterface consumerOp, + llvm::function_ref isPreservingCast, + llvm::function_ref + createCast); + +/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op +/// if the casts make their operands less static. See also isPreservingCast +/// above. +template +LogicalResult foldCastProducers(DestinationStyleOpInterface op, + RewriterBase &rewriter) { + return foldCastProducers( + rewriter, op, + [](Operation *castOp) -> bool { + auto concreteCast = dyn_cast(castOp); + if (!concreteCast) + return false; + RankedTensorType resultType = + dyn_cast(concreteCast.getType()); + RankedTensorType sourceType = + dyn_cast(concreteCast->getOperand(0).getType()); + if (!resultType || !sourceType) + return false; + return resultType.isGeneralizationOf(sourceType); + }, + [](RewriterBase &rewriter, Type resultType, Value operand) -> Value { + return rewriter.create(operand.getLoc(), resultType, + operand); + }); +} + +/// A generic pattern for an Operation type that implements +/// DestinationStyleOpInterface, allowing for absorbing cast-like operations +/// that are producers of operands. +template +struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + DestinationStyleOpInterface dpsOp = + llvm::dyn_cast(op.getOperation()); + if (!dpsOp) + return failure(); + return foldCastProducers(dpsOp, rewriter); + } +}; + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index db38e2e1bce22..cbb20f42f9fb9 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -116,6 +116,15 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { auto clone(::llvm::ArrayRef shape) { return cloneWith(shape, getElementType()); } + + /// Return whether the target shape is a refinement of the source shape. + static bool isShapeRefinementOf( + ArrayRef source, ArrayRef target); + + /// Return whether the target shape is a generalization of the source + /// shape. + static bool isShapeGeneralizationOf( + ArrayRef source, ArrayRef target); }]; let extraSharedClassDeclaration = [{ @@ -185,6 +194,16 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { return llvm::count_if($_type.getShape().take_front(index), ::mlir::ShapedType::isDynamic); } + + bool isRefinementOf(ShapedType source) { + return $_type.getElementType() == source.getElementType() && + ShapedType::isShapeRefinementOf(source.getShape(), $_type.getShape()); + } + + bool isGeneralizationOf(ShapedType source) { + return $_type.getElementType() == source.getElementType() && + ShapedType::isShapeGeneralizationOf(source.getShape(), $_type.getShape()); + } }]; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 4cade83dd3c32..cf9489711aa17 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -837,6 +837,33 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ using ShapedType::Trait::getDimSize; using ShapedType::Trait::getDynamicDimIndex; + /// Return whether this type is a refinement of `source` with + /// respect to only the shape, meaning they ave the same element type + /// and the shape of this type is the same as source, except + /// zero or more dynamic extents from source have been replaced with + /// static extents. + /// This method is conservative with respect to the encoding. If the + /// encodings are not the same, then false is returned. + bool isRefinementOf(RankedTensorType source) { + + return getEncoding() == source.getEncoding() && + ShapedType::Trait::isRefinementOf( + llvm::cast(source)); + } + + /// Return whether this type is a generalization of `source` with + /// respect to only the shape, meaning they have the same element + /// type and the shape of this type is the same as source, except + /// zero or more static extents have been replaced with unknown + /// extents. + /// This method is conservative with respect to the encoding. If the + /// encodings are not the same, then false is returned. + bool isGeneralizationOf(RankedTensorType source) { + return getEncoding() == source.getEncoding() && + ShapedType::Trait::isGeneralizationOf( + llvm::cast(source)); + } + /// This is a builder type that keeps local references to arguments. /// Arguments that are passed into the builder must outlive the builder. class Builder; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index e5f83331baf81..549fd37744d71 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2674,10 +2674,28 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { // LinalgDialect //===----------------------------------------------------------------------===// +namespace { +struct LinalgAbsorbTensorCastProducersPattern + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + DestinationStyleOpInterface dpsOp = + llvm::dyn_cast(op.getOperation()); + if (!dpsOp) + return failure(); + return foldCastProducers(dpsOp, rewriter); + } +}; +} // namespace + void LinalgDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { - results.add(getContext()); + results + .add( + getContext()); } Operation *LinalgDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 4c65045084dc5..38f8a2a288020 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -1369,6 +1370,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { return {}; } +void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + //===----------------------------------------------------------------------===// // GenerateOp //===----------------------------------------------------------------------===// @@ -2413,7 +2419,9 @@ void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add< OpWithOffsetSizesAndStridesConstantArgumentFolder< ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>, - ExtractSliceOpCastFolder>(context); + ExtractSliceOpCastFolder, + FoldTensorCastIntoConsumerPattern>(context); } // @@ -4154,6 +4162,15 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl &srcShape, } LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { + // pack(cast(x)) -> pack(x) + if (packOp.getSource().getDefiningOp() || + packOp.getDest().getDefiningOp()) { + if (succeeded(foldCastProducers( + cast(packOp.getOperation()), + rewriter))) + return success(); + } + // Fold an unpack(pack(x)) to x. if (auto unPackOp = packOp.getSource().getDefiningOp()) { if (unPackOp.getSourceType() != packOp.getDestType()) @@ -4388,6 +4405,15 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl &srcShape, LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, PatternRewriter &rewriter) { + // pack(cast(x)) -> pack(x) + if (unPackOp.getSource().getDefiningOp() || + unPackOp.getDest().getDefiningOp()) { + if (succeeded(foldCastProducers( + cast(unPackOp.getOperation()), + rewriter))) + return success(); + } + /// pack(unpack(x)) -> x if (PackOp packOp = unPackOp.getSource().getDefiningOp()) { if (packOp.getDestType() != unPackOp.getSourceType()) @@ -4533,9 +4559,7 @@ struct FoldTensorCastProducerOp //===----------------------------------------------------------------------===// void TensorDialect::getCanonicalizationPatterns( - RewritePatternSet &results) const { - results.add(getContext()); -} + RewritePatternSet &results) const {} //===----------------------------------------------------------------------===// // TableGen'd op method definitions diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp index adde8a66d8354..131a67cede5aa 100644 --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -228,3 +228,56 @@ mlir::getPrunedAttributeList(Operation *op, ArrayRef elidedAttrs) { } return attrs; } + +LogicalResult mlir::foldCastProducers( + RewriterBase &rewriter, DestinationStyleOpInterface op, + llvm::function_ref isPreservingCast, + llvm::function_ref + createCast) { + + auto canFoldIntoConsumerOp = [&isPreservingCast](Operation *castOp) { + return castOp && isPreservingCast(castOp); + }; + + // If no operand comes from a tensor::CastOp and can be folded then fail. + bool hasTensorCastOperand = + llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { + if (llvm::isa(opOperand.get())) + return false; + Operation *castOp = opOperand.get().getDefiningOp(); + return castOp && canFoldIntoConsumerOp(castOp); + }); + if (!hasTensorCastOperand) + return failure(); + + SmallVector newResultTypes; + newResultTypes.reserve(op->getNumResults()); + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + for (OpOperand &opOperand : op->getOpOperands()) { + Operation *tensorCastOp = opOperand.get().getDefiningOp(); + bool fold = canFoldIntoConsumerOp(tensorCastOp); + newOperands.push_back(fold ? tensorCastOp->getOperand(0) : opOperand.get()); + if (op.isDpsInit(&opOperand) && + !llvm::isa(newOperands.back().getType())) + newResultTypes.push_back(newOperands.back().getType()); + } + + // Clone op. + Operation *newOp = clone(rewriter, op, newResultTypes, newOperands); + SmallVector replacements; + replacements.reserve(newOp->getNumResults()); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (newResult.getType() != oldResult.getType()) { + Value resultCast = createCast(rewriter, oldResult.getType(), newResult); + replacements.push_back(resultCast); + } else { + replacements.push_back(newResult); + } + } + rewriter.replaceOp(op, replacements); + + return success(); +} \ No newline at end of file diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d10a31941db4f..2b18e17d8d250 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4184,7 +4184,10 @@ struct TransferReadAfterWriteToBroadcast void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add>( + context); } //===----------------------------------------------------------------------===// @@ -4636,7 +4639,10 @@ struct SwapExtractSliceOfTransferWrite void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp index ab9e65b5edfed..50a8330d01dc4 100644 --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -33,3 +33,27 @@ int64_t ShapedType::getNumElements(ArrayRef shape) { } return num; } + +bool ShapedType::isShapeRefinementOf(ArrayRef source, + ArrayRef target) { + if (source.size() != target.size()) + return false; + for (auto [srcDim, tgtDim] : llvm::zip_equal(source, target)) { + // If the source dimension is dynamic, then the target dimension can be + // dynamic or static. + if (isDynamic(srcDim)) + continue; + // Static source dim and dynamic result dim -> not a refinement. + if (isDynamic(tgtDim)) + return false; + // Static source dim != static result dim -> not a refinement. + if (srcDim != tgtDim) + return false; + } + return true; +} + +bool ShapedType::isShapeGeneralizationOf(ArrayRef source, + ArrayRef target) { + return isShapeRefinementOf(target, source); +} diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir index b6c0a0e25efe0..bae1ff805d939 100644 --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -388,3 +388,19 @@ func.func @negative_input() -> tensor { %11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor return %11 : tensor } + +// ----- + +func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor { + %0 = bufferization.alloc_tensor(%arg1) : tensor + %1 = tensor.cast %arg0 : tensor<4xf32> to tensor + %2 = bufferization.materialize_in_destination %1 in %0 : (tensor, tensor) -> tensor + return %2 : tensor +} + +// Check that a `tensor.cast` producer is not absorbed. + +// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast +// CHECK: tensor.cast +// CHECK: bufferization.materialize_in_destination +// CHECK-SAME: : (tensor, tensor) -> tensor