[MLIR][Tensor] Fix DPS op canonicalizer with `tensor.cast`
#91382
Conversation
Attempts to address a bug pointed out in llvm#91265 by moving the FoldTensorCastProducerOp canonicalizer definition upward into the MLIRDialectUtils library. Since MLIRDialectUtils can't depend on any dialect, the canonicalizer had to change slightly, and a templated version is introduced. Then, we need to add this canonicalization routine where it was used before, except for the places where it is incorrect, as pointed out in the bug. Based on cursory inspection of the TableGen definitions, only `bufferization.materialize_in_destination` should *not* have the canonicalizer, but existing tests passed when the canonicalizer was only added for `tensor.pack|unpack|extract_slice` and the LinalgOp interface.
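For concreteness, the regression test added in this PR (see the diff below) captures the case that must not fold; a condensed view:

```mlir
%0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
%1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
// Per the linked bug, absorbing this cast into the consumer is incorrect for
// this op; the new test checks that the cast survives canonicalization.
%2 = bufferization.materialize_in_destination %1 in %0
    : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
```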
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-linalg

Author: Christopher Bate (christopherbate)

Changes: (see the PR description above)

Full diff: https://github.com/llvm/llvm-project/pull/91382.diff — 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..228317b4adb6c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -818,6 +818,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 929a2a7d39649..a1f4ddc1221fc 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -20,7 +20,9 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/LLVM.h"
// Pull in all enum type definitions and utility function declarations.
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
SmallVector<NamedAttribute>
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
+/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand',
+/// then the tied result type is updated as well to the type of the cast source,
+/// and a new cast must be inserted on the new op's result. `createCast` is used
+/// to build such required cast ops.
+///
+/// ### Example
+/// If `isPreservingCast` returns true when the cast is a "generalizing"
+/// `tensor.cast`, then this function behaves as follows:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
+/// ```
+LogicalResult foldCastProducers(
+ RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
+ llvm::function_ref<bool(Operation *)> isPreservingCast,
+ llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+ Value replacement)>
+ createCast);
+
+/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
+/// if the casts make their operands less static. See also isPreservingCast
+/// above.
+template <typename CastOpType>
+LogicalResult foldCastProducers(DestinationStyleOpInterface op,
+ RewriterBase &rewriter) {
+ return foldCastProducers(
+ rewriter, op,
+ [](Operation *castOp) -> bool {
+ auto concreteCast = dyn_cast<CastOpType>(castOp);
+ if (!concreteCast)
+ return false;
+ RankedTensorType resultType =
+ dyn_cast<RankedTensorType>(concreteCast.getType());
+ RankedTensorType sourceType =
+ dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
+ if (!resultType || !sourceType)
+ return false;
+ return resultType.isGeneralizationOf(sourceType);
+ },
+ [](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
+ return rewriter.create<CastOpType>(operand.getLoc(), resultType,
+ operand);
+ });
+}
+
+/// A generic pattern for an Operation type that implements
+/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
+/// that are producers of operands.
+template <typename OpType, typename CastOpType>
+struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
+ using OpRewritePattern<OpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpType op,
+ PatternRewriter &rewriter) const override {
+ DestinationStyleOpInterface dpsOp =
+ llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!dpsOp)
+ return failure();
+ return foldCastProducers<CastOpType>(dpsOp, rewriter);
+ }
+};
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index db38e2e1bce22..cbb20f42f9fb9 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -116,6 +116,15 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
auto clone(::llvm::ArrayRef<int64_t> shape) {
return cloneWith(shape, getElementType());
}
+
+ /// Return whether the target shape is a refinement of the source shape.
+ static bool isShapeRefinementOf(
+ ArrayRef<int64_t> source, ArrayRef<int64_t> target);
+
+ /// Return whether the target shape is a generalization of the source
+ /// shape.
+ static bool isShapeGeneralizationOf(
+ ArrayRef<int64_t> source, ArrayRef<int64_t> target);
}];
let extraSharedClassDeclaration = [{
@@ -185,6 +194,16 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
return llvm::count_if($_type.getShape().take_front(index),
::mlir::ShapedType::isDynamic);
}
+
+ bool isRefinementOf(ShapedType source) {
+ return $_type.getElementType() == source.getElementType() &&
+ ShapedType::isShapeRefinementOf(source.getShape(), $_type.getShape());
+ }
+
+ bool isGeneralizationOf(ShapedType source) {
+ return $_type.getElementType() == source.getElementType() &&
+ ShapedType::isShapeGeneralizationOf(source.getShape(), $_type.getShape());
+ }
}];
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..cf9489711aa17 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -837,6 +837,33 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
using ShapedType::Trait<RankedTensorType>::getDimSize;
using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
+ /// Return whether this type is a refinement of `source` with
+  /// respect to only the shape, meaning they have the same element type
+ /// and the shape of this type is the same as source, except
+ /// zero or more dynamic extents from source have been replaced with
+ /// static extents.
+ /// This method is conservative with respect to the encoding. If the
+ /// encodings are not the same, then false is returned.
+ bool isRefinementOf(RankedTensorType source) {
+
+ return getEncoding() == source.getEncoding() &&
+ ShapedType::Trait<RankedTensorType>::isRefinementOf(
+ llvm::cast<ShapedType>(source));
+ }
+
+ /// Return whether this type is a generalization of `source` with
+ /// respect to only the shape, meaning they have the same element
+ /// type and the shape of this type is the same as source, except
+ /// zero or more static extents have been replaced with unknown
+ /// extents.
+ /// This method is conservative with respect to the encoding. If the
+ /// encodings are not the same, then false is returned.
+ bool isGeneralizationOf(RankedTensorType source) {
+ return getEncoding() == source.getEncoding() &&
+ ShapedType::Trait<RankedTensorType>::isGeneralizationOf(
+ llvm::cast<ShapedType>(source));
+ }
+
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..549fd37744d71 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2674,10 +2674,28 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
// LinalgDialect
//===----------------------------------------------------------------------===//
+namespace {
+struct LinalgAbsorbTensorCastProducersPattern
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ DestinationStyleOpInterface dpsOp =
+ llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!dpsOp)
+ return failure();
+ return foldCastProducers<tensor::CastOp>(dpsOp, rewriter);
+ }
+};
+} // namespace
+
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
- results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
- InferStaticShapeOfOperands>(getContext());
+ results
+ .add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
+ InferStaticShapeOfOperands, LinalgAbsorbTensorCastProducersPattern>(
+ getContext());
}
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5..38f8a2a288020 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1369,6 +1370,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldTensorCastIntoConsumerPattern<InsertOp, CastOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
@@ -2413,7 +2419,9 @@ void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<
OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
- ExtractSliceOpCastFolder>(context);
+ ExtractSliceOpCastFolder,
+ FoldTensorCastIntoConsumerPattern<tensor::ExtractSliceOp,
+ tensor::CastOp>>(context);
}
//
@@ -4154,6 +4162,15 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+ // pack(cast(x)) -> pack(x)
+ if (packOp.getSource().getDefiningOp<tensor::CastOp>() ||
+ packOp.getDest().getDefiningOp<tensor::CastOp>()) {
+ if (succeeded(foldCastProducers<CastOp>(
+ cast<DestinationStyleOpInterface>(packOp.getOperation()),
+ rewriter)))
+ return success();
+ }
+
// Fold an unpack(pack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
if (unPackOp.getSourceType() != packOp.getDestType())
@@ -4388,6 +4405,15 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
+  // unpack(cast(x)) -> unpack(x)
+ if (unPackOp.getSource().getDefiningOp<tensor::CastOp>() ||
+ unPackOp.getDest().getDefiningOp<tensor::CastOp>()) {
+ if (succeeded(foldCastProducers<CastOp>(
+ cast<DestinationStyleOpInterface>(unPackOp.getOperation()),
+ rewriter)))
+ return success();
+ }
+
/// pack(unpack(x)) -> x
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
if (packOp.getDestType() != unPackOp.getSourceType())
@@ -4533,9 +4559,7 @@ struct FoldTensorCastProducerOp
//===----------------------------------------------------------------------===//
void TensorDialect::getCanonicalizationPatterns(
- RewritePatternSet &results) const {
- results.add<FoldTensorCastProducerOp>(getContext());
-}
+ RewritePatternSet &results) const {}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index adde8a66d8354..131a67cede5aa 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -228,3 +228,56 @@ mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
}
return attrs;
}
+
+LogicalResult mlir::foldCastProducers(
+ RewriterBase &rewriter, DestinationStyleOpInterface op,
+ llvm::function_ref<bool(Operation *)> isPreservingCast,
+ llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+ Value replacement)>
+ createCast) {
+
+ auto canFoldIntoConsumerOp = [&isPreservingCast](Operation *castOp) {
+ return castOp && isPreservingCast(castOp);
+ };
+
+  // If no operand comes from a foldable cast-like op, then fail.
+ bool hasTensorCastOperand =
+ llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+ if (llvm::isa<BlockArgument>(opOperand.get()))
+ return false;
+ Operation *castOp = opOperand.get().getDefiningOp();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+ if (!hasTensorCastOperand)
+ return failure();
+
+ SmallVector<Type, 4> newResultTypes;
+ newResultTypes.reserve(op->getNumResults());
+ SmallVector<Value, 4> newOperands;
+ newOperands.reserve(op->getNumOperands());
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ Operation *tensorCastOp = opOperand.get().getDefiningOp();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp->getOperand(0) : opOperand.get());
+ if (op.isDpsInit(&opOperand) &&
+ !llvm::isa<MemRefType>(newOperands.back().getType()))
+ newResultTypes.push_back(newOperands.back().getType());
+ }
+
+ // Clone op.
+ Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+ SmallVector<Value, 4> replacements;
+ replacements.reserve(newOp->getNumResults());
+ for (auto [oldResult, newResult] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ if (newResult.getType() != oldResult.getType()) {
+ Value resultCast = createCast(rewriter, oldResult.getType(), newResult);
+ replacements.push_back(resultCast);
+ } else {
+ replacements.push_back(newResult);
+ }
+ }
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..2b18e17d8d250 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4184,7 +4184,10 @@ struct TransferReadAfterWriteToBroadcast
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<TransferReadAfterWriteToBroadcast>(context);
+ results
+ .add<TransferReadAfterWriteToBroadcast,
+ FoldTensorCastIntoConsumerPattern<TransferReadOp, tensor::CastOp>>(
+ context);
}
//===----------------------------------------------------------------------===//
@@ -4636,7 +4639,10 @@ struct SwapExtractSliceOfTransferWrite
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
+ results
+ .add<FoldWaw, SwapExtractSliceOfTransferWrite,
+ FoldTensorCastIntoConsumerPattern<TransferWriteOp, tensor::CastOp>>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed..50a8330d01dc4 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -33,3 +33,27 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
}
return num;
}
+
+bool ShapedType::isShapeRefinementOf(ArrayRef<int64_t> source,
+ ArrayRef<int64_t> target) {
+ if (source.size() != target.size())
+ return false;
+ for (auto [srcDim, tgtDim] : llvm::zip_equal(source, target)) {
+ // If the source dimension is dynamic, then the target dimension can be
+ // dynamic or static.
+ if (isDynamic(srcDim))
+ continue;
+ // Static source dim and dynamic result dim -> not a refinement.
+ if (isDynamic(tgtDim))
+ return false;
+ // Static source dim != static result dim -> not a refinement.
+ if (srcDim != tgtDim)
+ return false;
+ }
+ return true;
+}
+
+bool ShapedType::isShapeGeneralizationOf(ArrayRef<int64_t> source,
+ ArrayRef<int64_t> target) {
+ return isShapeRefinementOf(target, source);
+}
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0..bae1ff805d939 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -388,3 +388,19 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
%11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
return %11 : tensor<?x?x?xf16>
}
+
+// -----
+
+func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
+ %0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
+ %1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
+ %2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+}
+
+// Check that a `tensor.cast` producer is not absorbed.
+
+// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
+// CHECK: tensor.cast
+// CHECK: bufferization.materialize_in_destination
+// CHECK-SAME: : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
Commit title incorrectly says
Also, just realized I forgot to remove the old canonicalizer this code is replacing; will update later today.
MaheshRavishankar left a comment
Haven't looked at the change yet, but this is a much bigger change than I would have anticipated.
Wouldn't a solution be to create a new operation where the result type is also of the same type as the new destination?
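(For reference, that is essentially what the new `foldCastProducers` utility does; a condensed sketch based on its doc comment, with `dps_op` as the doc comment's placeholder and the trailing cast re-created via `createCast` when the result type changes:)

```mlir
// Before: a generalizing cast feeds a DPS op.
%1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
%2 = dps_op %1 ... : tensor<?x?xf32> ...

// After: the op is recreated with the cast's source; for an init operand the
// tied result type is updated too, and a cast is rebuilt on the result so
// existing consumers still see the old type.
%2 = dps_op %0 ... : tensor<8x16xf32> ...
%3 = tensor.cast %2 : tensor<8x16xf32> to tensor<?x?xf32>
```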
MaheshRavishankar left a comment
On further look, it seems to be doing that... OK, I'll take a look in a bit.
It's mostly just code movement. I'll push an update to remove the old pattern so that's more clear.
That's not the problem/solution. The old pattern already did that for DPS init operands. The problem, as described in the linked issue, is that you can't assume any DPS operation can absorb casts on its operands (even the DPS input operands). A concrete op may have arbitrary logic in its verifier that would make that invalid. A contrived example: I could create an op in a downstream project where the shapes of all operands have to be dynamic.
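A sketch of that contrived case (`mydialect.all_dynamic` is a hypothetical downstream op whose verifier requires every operand shape to be dynamic):

```mlir
%1 = tensor.cast %0 : tensor<8xf32> to tensor<?xf32>
%2 = mydialect.all_dynamic %1 : (tensor<?xf32>) -> tensor<?xf32>

// A blanket cast-absorbing canonicalizer would rewrite this to IR that the
// op's verifier rejects:
%2 = mydialect.all_dynamic %0 : (tensor<8xf32>) -> tensor<8xf32>  // invalid
```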
MaheshRavishankar left a comment
Haven't had a chance to get back to this yet, but one high-level comment:
/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
/// that are producers of operands.
template <typename OpType, typename CastOpType>
struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
I think we try to not have template patterns be part of header files this way. I see why you have this, but is there a way to avoid doing this?
We can duplicate it in dialects where it is used. It's not much code, so that would seem fine. I do recall seeing templated patterns like this in ReshapeOpsUtils.h though.
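A minimal sketch of what that duplication would look like, mirroring this PR's `LinalgAbsorbTensorCastProducersPattern` in LinalgOps.cpp (`MyDpsOp` is a placeholder for a concrete op type):

```cpp
namespace {
// Local copy of the pattern, kept in the dialect's .cpp rather than the
// shared header.
struct AbsorbCastProducersPattern : public OpRewritePattern<MyDpsOp> {
  using OpRewritePattern<MyDpsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyDpsOp op,
                                PatternRewriter &rewriter) const override {
    auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
    if (!dpsOp)
      return failure();
    return foldCastProducers<tensor::CastOp>(dpsOp, rewriter);
  }
};
} // namespace
```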