Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down
72 changes: 72 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/LLVM.h"

// Pull in all enum type definitions and utility function declarations.
Expand Down Expand Up @@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
SmallVector<NamedAttribute>
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);

/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand',
/// then the tied result type is updated as well to the type of the cast source,
/// and a new cast must be inserted on the new op's result. `createCast` is used
/// to build such required cast ops.
///
/// ### Example
/// If the `isPreservingCast` returns true if the cast is a "generalizing"
/// `tensor.cast`, then this function would be have as follows:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
/// ```
LogicalResult foldCastProducers(
RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
llvm::function_ref<bool(Operation *)> isPreservingCast,
llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
Value replacement)>
createCast);

/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
/// if the casts make their operands less static. See also isPreservingCast
/// above.
template <typename CastOpType>
LogicalResult foldCastProducers(DestinationStyleOpInterface op,
RewriterBase &rewriter) {
return foldCastProducers(
rewriter, op,
[](Operation *castOp) -> bool {
auto concreteCast = dyn_cast<CastOpType>(castOp);
if (!concreteCast)
return false;
RankedTensorType resultType =
dyn_cast<RankedTensorType>(concreteCast.getType());
RankedTensorType sourceType =
dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
if (!resultType || !sourceType)
return false;
return resultType.isGeneralizationOf(sourceType);
},
[](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
return rewriter.create<CastOpType>(operand.getLoc(), resultType,
operand);
});
}

/// A generic pattern for an Operation type that implements
/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
/// that are producers of operands.
template <typename OpType, typename CastOpType>
struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we try to not have template patterns be part of header files this way. I see why you have this, but is there a way to avoid doing this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can duplicate it in dialects where it is used. It's not much code, so that would seem fine. I do recall seeing templated patterns like this in ReshapeOpsUtils.h though.

using OpRewritePattern<OpType>::OpRewritePattern;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
DestinationStyleOpInterface dpsOp =
llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
if (!dpsOp)
return failure();
return foldCastProducers<CastOpType>(dpsOp, rewriter);
}
};

} // namespace mlir

#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
auto clone(::llvm::ArrayRef<int64_t> shape) {
return cloneWith(shape, getElementType());
}

/// Return whether the target shape is a refinement of the source shape.
static bool isShapeRefinementOf(
ArrayRef<int64_t> source, ArrayRef<int64_t> target);

/// Return whether the target shape is a generalization of the source
/// shape.
static bool isShapeGeneralizationOf(
ArrayRef<int64_t> source, ArrayRef<int64_t> target);
}];

let extraSharedClassDeclaration = [{
Expand Down Expand Up @@ -185,6 +194,16 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
return llvm::count_if($_type.getShape().take_front(index),
::mlir::ShapedType::isDynamic);
}

bool isRefinementOf(ShapedType source) {
return $_type.getElementType() == source.getElementType() &&
ShapedType::isShapeRefinementOf(source.getShape(), $_type.getShape());
}

bool isGeneralizationOf(ShapedType source) {
return $_type.getElementType() == source.getElementType() &&
ShapedType::isShapeGeneralizationOf(source.getShape(), $_type.getShape());
}
}];
}

Expand Down
27 changes: 27 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,33 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
using ShapedType::Trait<RankedTensorType>::getDimSize;
using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;

/// Return whether this type is a refinement of `source` with
/// respect to only the shape, meaning they ave the same element type
/// and the shape of this type is the same as source, except
/// zero or more dynamic extents from source have been replaced with
/// static extents.
/// This method is conservative with respect to the encoding. If the
/// encodings are not the same, then false is returned.
bool isRefinementOf(RankedTensorType source) {

return getEncoding() == source.getEncoding() &&
ShapedType::Trait<RankedTensorType>::isRefinementOf(
llvm::cast<ShapedType>(source));
}

/// Return whether this type is a generalization of `source` with
/// respect to only the shape, meaning they have the same element
/// type and the shape of this type is the same as source, except
/// zero or more static extents have been replaced with unknown
/// extents.
/// This method is conservative with respect to the encoding. If the
/// encodings are not the same, then false is returned.
bool isGeneralizationOf(RankedTensorType source) {
return getEncoding() == source.getEncoding() &&
ShapedType::Trait<RankedTensorType>::isGeneralizationOf(
llvm::cast<ShapedType>(source));
}

/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
Expand Down
22 changes: 20 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2674,10 +2674,28 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
// LinalgDialect
//===----------------------------------------------------------------------===//

namespace {
struct LinalgAbsorbTensorCastProducersPattern
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
DestinationStyleOpInterface dpsOp =
llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
if (!dpsOp)
return failure();
return foldCastProducers<tensor::CastOp>(dpsOp, rewriter);
}
};
} // namespace

void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
InferStaticShapeOfOperands>(getContext());
results
.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
InferStaticShapeOfOperands, LinalgAbsorbTensorCastProducersPattern>(
getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Expand Down
32 changes: 28 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
Expand Down Expand Up @@ -1369,6 +1370,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldTensorCastIntoConsumerPattern<InsertOp, CastOp>>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2413,7 +2419,9 @@ void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<
OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
ExtractSliceOpCastFolder>(context);
ExtractSliceOpCastFolder,
FoldTensorCastIntoConsumerPattern<tensor::ExtractSliceOp,
tensor::CastOp>>(context);
}

//
Expand Down Expand Up @@ -4154,6 +4162,15 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
}

LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// pack(cast(x)) -> pack(x)
if (packOp.getSource().getDefiningOp<tensor::CastOp>() ||
packOp.getDest().getDefiningOp<tensor::CastOp>()) {
if (succeeded(foldCastProducers<CastOp>(
cast<DestinationStyleOpInterface>(packOp.getOperation()),
rewriter)))
return success();
}

// Fold an unpack(pack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
if (unPackOp.getSourceType() != packOp.getDestType())
Expand Down Expand Up @@ -4388,6 +4405,15 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,

LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
// pack(cast(x)) -> pack(x)
if (unPackOp.getSource().getDefiningOp<tensor::CastOp>() ||
unPackOp.getDest().getDefiningOp<tensor::CastOp>()) {
if (succeeded(foldCastProducers<CastOp>(
cast<DestinationStyleOpInterface>(unPackOp.getOperation()),
rewriter)))
return success();
}

/// pack(unpack(x)) -> x
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
if (packOp.getDestType() != unPackOp.getSourceType())
Expand Down Expand Up @@ -4533,9 +4559,7 @@ struct FoldTensorCastProducerOp
//===----------------------------------------------------------------------===//

void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastProducerOp>(getContext());
}
RewritePatternSet &results) const {}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
Expand Down
53 changes: 53 additions & 0 deletions mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,56 @@ mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
}
return attrs;
}

LogicalResult mlir::foldCastProducers(
RewriterBase &rewriter, DestinationStyleOpInterface op,
llvm::function_ref<bool(Operation *)> isPreservingCast,
llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
Value replacement)>
createCast) {

auto canFoldIntoConsumerOp = [&isPreservingCast](Operation *castOp) {
return castOp && isPreservingCast(castOp);
};

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
Operation *castOp = opOperand.get().getDefiningOp();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
return failure();

SmallVector<Type, 4> newResultTypes;
newResultTypes.reserve(op->getNumResults());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
for (OpOperand &opOperand : op->getOpOperands()) {
Operation *tensorCastOp = opOperand.get().getDefiningOp();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp->getOperand(0) : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResultTypes.push_back(newOperands.back().getType());
}

// Clone op.
Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto [oldResult, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
if (newResult.getType() != oldResult.getType()) {
Value resultCast = createCast(rewriter, oldResult.getType(), newResult);
replacements.push_back(resultCast);
} else {
replacements.push_back(newResult);
}
}
rewriter.replaceOp(op, replacements);

return success();
}
10 changes: 8 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4184,7 +4184,10 @@ struct TransferReadAfterWriteToBroadcast

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<TransferReadAfterWriteToBroadcast>(context);
results
.add<TransferReadAfterWriteToBroadcast,
FoldTensorCastIntoConsumerPattern<TransferReadOp, tensor::CastOp>>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4636,7 +4639,10 @@ struct SwapExtractSliceOfTransferWrite

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
results
.add<FoldWaw, SwapExtractSliceOfTransferWrite,
FoldTensorCastIntoConsumerPattern<TransferWriteOp, tensor::CastOp>>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/IR/BuiltinTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,27 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
}
return num;
}

bool ShapedType::isShapeRefinementOf(ArrayRef<int64_t> source,
ArrayRef<int64_t> target) {
if (source.size() != target.size())
return false;
for (auto [srcDim, tgtDim] : llvm::zip_equal(source, target)) {
// If the source dimension is dynamic, then the target dimension can be
// dynamic or static.
if (isDynamic(srcDim))
continue;
// Static source dim and dynamic result dim -> not a refinement.
if (isDynamic(tgtDim))
return false;
// Static source dim != static result dim -> not a refinement.
if (srcDim != tgtDim)
return false;
}
return true;
}

bool ShapedType::isShapeGeneralizationOf(ArrayRef<int64_t> source,
ArrayRef<int64_t> target) {
return isShapeRefinementOf(target, source);
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Bufferization/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,19 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
%11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
return %11 : tensor<?x?x?xf16>
}

// -----

func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
%0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
%1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
%2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}

// Check that a `tensor.cast` producer is not absorbed.

// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
// CHECK: tensor.cast
// CHECK: bufferization.materialize_in_destination
// CHECK-SAME: : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>