Skip to content

Commit 809aa65

Browse files
[MLIR][Bufferization] Fix DPS op canonicalizer with tensor.cast`
Attempts to address a bug pointed out in #91265 by moving the FoldTensorCastProducerOp canonicalizer definition upward into the MLIRDialectUtils library. Since the MLIRDialectUtils can't depend on any dialect, the canonicalizer had to change slightly, and a templated version is introduced. Then, we need to add this canonicalization routine where it was used before, except for places where it is incorrect as pointed out in the bug. Based on cursory inspection of the TableGen definitions, only `bufferization.materialize_in_destination` should *not* have the canonicalizer, but existing tests passed if the canonicalizer as only added for `tensor.pack|unpack|extract_slice` and the LinalgOp interface.
1 parent 94204f5 commit 809aa65

File tree

10 files changed

+268
-8
lines changed

10 files changed

+268
-8
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
818818
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
819819
}];
820820

821+
let hasCanonicalizer = 1;
821822
let hasFolder = 1;
822823
let hasVerifier = 1;
823824
}

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
#include "mlir/IR/AffineMap.h"
2121
#include "mlir/IR/BuiltinAttributes.h"
2222
#include "mlir/IR/Location.h"
23+
#include "mlir/IR/PatternMatch.h"
2324
#include "mlir/IR/TypeRange.h"
25+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2426
#include "mlir/Support/LLVM.h"
2527

2628
// Pull in all enum type definitions and utility function declarations.
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
158160
SmallVector<NamedAttribute>
159161
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
160162

163+
/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
164+
/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand',
165+
/// then the tied result type is updated as well to the type of the cast source,
166+
/// and a new cast must be inserted on the new op's result. `createCast` is used
167+
/// to build such required cast ops.
168+
///
169+
/// ### Example
170+
/// If the `isPreservingCast` returns true if the cast is a "generalizing"
171+
/// `tensor.cast`, then this function would be have as follows:
172+
///
173+
/// ```mlir
174+
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
175+
/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
176+
/// ```
177+
///
178+
/// folds into:
179+
///
180+
/// ```mlir
181+
/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
182+
/// ```
183+
LogicalResult foldCastProducers(
184+
RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
185+
llvm::function_ref<bool(Operation *)> isPreservingCast,
186+
llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
187+
Value replacement)>
188+
createCast);
189+
190+
/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
191+
/// if the casts make their operands less static. See also isPreservingCast
192+
/// above.
193+
template <typename CastOpType>
194+
LogicalResult foldCastProducers(DestinationStyleOpInterface op,
195+
RewriterBase &rewriter) {
196+
return foldCastProducers(
197+
rewriter, op,
198+
[](Operation *castOp) -> bool {
199+
auto concreteCast = dyn_cast<CastOpType>(castOp);
200+
if (!concreteCast)
201+
return false;
202+
RankedTensorType resultType =
203+
dyn_cast<RankedTensorType>(concreteCast.getType());
204+
RankedTensorType sourceType =
205+
dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
206+
if (!resultType || !sourceType)
207+
return false;
208+
return resultType.isGeneralizationOf(sourceType);
209+
},
210+
[](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
211+
return rewriter.create<CastOpType>(operand.getLoc(), resultType,
212+
operand);
213+
});
214+
}
215+
216+
/// A generic pattern for an Operation type that implements
217+
/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
218+
/// that are producers of operands.
219+
template <typename OpType, typename CastOpType>
220+
struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
221+
using OpRewritePattern<OpType>::OpRewritePattern;
222+
223+
LogicalResult matchAndRewrite(OpType op,
224+
PatternRewriter &rewriter) const override {
225+
DestinationStyleOpInterface dpsOp =
226+
llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
227+
if (!dpsOp)
228+
return failure();
229+
return foldCastProducers<CastOpType>(dpsOp, rewriter);
230+
}
231+
};
232+
161233
} // namespace mlir
162234

163235
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
116116
auto clone(::llvm::ArrayRef<int64_t> shape) {
117117
return cloneWith(shape, getElementType());
118118
}
119+
120+
/// Return whether the target shape is a refinement of the source shape.
121+
static bool isShapeRefinementOf(
122+
ArrayRef<int64_t> source, ArrayRef<int64_t> target);
123+
124+
/// Return whether the target shape is a generalization of the source
125+
/// shape.
126+
static bool isShapeGeneralizationOf(
127+
ArrayRef<int64_t> source, ArrayRef<int64_t> target);
119128
}];
120129

121130
let extraSharedClassDeclaration = [{
@@ -185,6 +194,16 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
185194
return llvm::count_if($_type.getShape().take_front(index),
186195
::mlir::ShapedType::isDynamic);
187196
}
197+
198+
bool isRefinementOf(ShapedType source) {
199+
return $_type.getElementType() == source.getElementType() &&
200+
ShapedType::isShapeRefinementOf(source.getShape(), $_type.getShape());
201+
}
202+
203+
bool isGeneralizationOf(ShapedType source) {
204+
return $_type.getElementType() == source.getElementType() &&
205+
ShapedType::isShapeGeneralizationOf(source.getShape(), $_type.getShape());
206+
}
188207
}];
189208
}
190209

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,33 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
837837
using ShapedType::Trait<RankedTensorType>::getDimSize;
838838
using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
839839

840+
/// Return whether this type is a refinement of `source` with
841+
/// respect to only the shape, meaning they ave the same element type
842+
/// and the shape of this type is the same as source, except
843+
/// zero or more dynamic extents from source have been replaced with
844+
/// static extents.
845+
/// This method is conservative with respect to the encoding. If the
846+
/// encodings are not the same, then false is returned.
847+
bool isRefinementOf(RankedTensorType source) {
848+
849+
return getEncoding() == source.getEncoding() &&
850+
ShapedType::Trait<RankedTensorType>::isRefinementOf(
851+
llvm::cast<ShapedType>(source));
852+
}
853+
854+
/// Return whether this type is a generalization of `source` with
855+
/// respect to only the shape, meaning they have the same element
856+
/// type and the shape of this type is the same as source, except
857+
/// zero or more static extents have been replaced with unknown
858+
/// extents.
859+
/// This method is conservative with respect to the encoding. If the
860+
/// encodings are not the same, then false is returned.
861+
bool isGeneralizationOf(RankedTensorType source) {
862+
return getEncoding() == source.getEncoding() &&
863+
ShapedType::Trait<RankedTensorType>::isGeneralizationOf(
864+
llvm::cast<ShapedType>(source));
865+
}
866+
840867
/// This is a builder type that keeps local references to arguments.
841868
/// Arguments that are passed into the builder must outlive the builder.
842869
class Builder;

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,10 +2674,28 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
26742674
// LinalgDialect
26752675
//===----------------------------------------------------------------------===//
26762676

2677+
namespace {
2678+
struct LinalgAbsorbTensorCastProducersPattern
2679+
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
2680+
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
2681+
2682+
LogicalResult matchAndRewrite(LinalgOp op,
2683+
PatternRewriter &rewriter) const override {
2684+
DestinationStyleOpInterface dpsOp =
2685+
llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
2686+
if (!dpsOp)
2687+
return failure();
2688+
return foldCastProducers<tensor::CastOp>(dpsOp, rewriter);
2689+
}
2690+
};
2691+
} // namespace
2692+
26772693
void LinalgDialect::getCanonicalizationPatterns(
26782694
RewritePatternSet &results) const {
2679-
results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2680-
InferStaticShapeOfOperands>(getContext());
2695+
results
2696+
.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2697+
InferStaticShapeOfOperands, LinalgAbsorbTensorCastProducersPattern>(
2698+
getContext());
26812699
}
26822700

26832701
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Utils/IndexingUtils.h"
1515
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1616
#include "mlir/Dialect/Utils/StaticValueUtils.h"
17+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1718
#include "mlir/IR/Builders.h"
1819
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1920
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1369,6 +1370,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
13691370
return {};
13701371
}
13711372

1373+
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1374+
MLIRContext *context) {
1375+
results.add<FoldTensorCastIntoConsumerPattern<InsertOp, CastOp>>(context);
1376+
}
1377+
13721378
//===----------------------------------------------------------------------===//
13731379
// GenerateOp
13741380
//===----------------------------------------------------------------------===//
@@ -2413,7 +2419,9 @@ void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
24132419
results.add<
24142420
OpWithOffsetSizesAndStridesConstantArgumentFolder<
24152421
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2416-
ExtractSliceOpCastFolder>(context);
2422+
ExtractSliceOpCastFolder,
2423+
FoldTensorCastIntoConsumerPattern<tensor::ExtractSliceOp,
2424+
tensor::CastOp>>(context);
24172425
}
24182426

24192427
//
@@ -4154,6 +4162,15 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
41544162
}
41554163

41564164
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4165+
// pack(cast(x)) -> pack(x)
4166+
if (packOp.getSource().getDefiningOp<tensor::CastOp>() ||
4167+
packOp.getDest().getDefiningOp<tensor::CastOp>()) {
4168+
if (succeeded(foldCastProducers<CastOp>(
4169+
cast<DestinationStyleOpInterface>(packOp.getOperation()),
4170+
rewriter)))
4171+
return success();
4172+
}
4173+
41574174
// Fold an unpack(pack(x)) to x.
41584175
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
41594176
if (unPackOp.getSourceType() != packOp.getDestType())
@@ -4388,6 +4405,15 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
43884405

43894406
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
43904407
PatternRewriter &rewriter) {
4408+
// pack(cast(x)) -> pack(x)
4409+
if (unPackOp.getSource().getDefiningOp<tensor::CastOp>() ||
4410+
unPackOp.getDest().getDefiningOp<tensor::CastOp>()) {
4411+
if (succeeded(foldCastProducers<CastOp>(
4412+
cast<DestinationStyleOpInterface>(unPackOp.getOperation()),
4413+
rewriter)))
4414+
return success();
4415+
}
4416+
43914417
/// pack(unpack(x)) -> x
43924418
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
43934419
if (packOp.getDestType() != unPackOp.getSourceType())
@@ -4533,9 +4559,7 @@ struct FoldTensorCastProducerOp
45334559
//===----------------------------------------------------------------------===//
45344560

45354561
void TensorDialect::getCanonicalizationPatterns(
4536-
RewritePatternSet &results) const {
4537-
results.add<FoldTensorCastProducerOp>(getContext());
4538-
}
4562+
RewritePatternSet &results) const {}
45394563

45404564
//===----------------------------------------------------------------------===//
45414565
// TableGen'd op method definitions

mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,56 @@ mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
228228
}
229229
return attrs;
230230
}
231+
232+
LogicalResult mlir::foldCastProducers(
233+
RewriterBase &rewriter, DestinationStyleOpInterface op,
234+
llvm::function_ref<bool(Operation *)> isPreservingCast,
235+
llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
236+
Value replacement)>
237+
createCast) {
238+
239+
auto canFoldIntoConsumerOp = [&isPreservingCast](Operation *castOp) {
240+
return castOp && isPreservingCast(castOp);
241+
};
242+
243+
// If no operand comes from a tensor::CastOp and can be folded then fail.
244+
bool hasTensorCastOperand =
245+
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
246+
if (llvm::isa<BlockArgument>(opOperand.get()))
247+
return false;
248+
Operation *castOp = opOperand.get().getDefiningOp();
249+
return castOp && canFoldIntoConsumerOp(castOp);
250+
});
251+
if (!hasTensorCastOperand)
252+
return failure();
253+
254+
SmallVector<Type, 4> newResultTypes;
255+
newResultTypes.reserve(op->getNumResults());
256+
SmallVector<Value, 4> newOperands;
257+
newOperands.reserve(op->getNumOperands());
258+
for (OpOperand &opOperand : op->getOpOperands()) {
259+
Operation *tensorCastOp = opOperand.get().getDefiningOp();
260+
bool fold = canFoldIntoConsumerOp(tensorCastOp);
261+
newOperands.push_back(fold ? tensorCastOp->getOperand(0) : opOperand.get());
262+
if (op.isDpsInit(&opOperand) &&
263+
!llvm::isa<MemRefType>(newOperands.back().getType()))
264+
newResultTypes.push_back(newOperands.back().getType());
265+
}
266+
267+
// Clone op.
268+
Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
269+
SmallVector<Value, 4> replacements;
270+
replacements.reserve(newOp->getNumResults());
271+
for (auto [oldResult, newResult] :
272+
llvm::zip(op->getResults(), newOp->getResults())) {
273+
if (newResult.getType() != oldResult.getType()) {
274+
Value resultCast = createCast(rewriter, oldResult.getType(), newResult);
275+
replacements.push_back(resultCast);
276+
} else {
277+
replacements.push_back(newResult);
278+
}
279+
}
280+
rewriter.replaceOp(op, replacements);
281+
282+
return success();
283+
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4184,7 +4184,10 @@ struct TransferReadAfterWriteToBroadcast
41844184

41854185
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
41864186
MLIRContext *context) {
4187-
results.add<TransferReadAfterWriteToBroadcast>(context);
4187+
results
4188+
.add<TransferReadAfterWriteToBroadcast,
4189+
FoldTensorCastIntoConsumerPattern<TransferReadOp, tensor::CastOp>>(
4190+
context);
41884191
}
41894192

41904193
//===----------------------------------------------------------------------===//
@@ -4636,7 +4639,10 @@ struct SwapExtractSliceOfTransferWrite
46364639

46374640
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
46384641
MLIRContext *context) {
4639-
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4642+
results
4643+
.add<FoldWaw, SwapExtractSliceOfTransferWrite,
4644+
FoldTensorCastIntoConsumerPattern<TransferWriteOp, tensor::CastOp>>(
4645+
context);
46404646
}
46414647

46424648
//===----------------------------------------------------------------------===//

mlir/lib/IR/BuiltinTypeInterfaces.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,27 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
3333
}
3434
return num;
3535
}
36+
37+
bool ShapedType::isShapeRefinementOf(ArrayRef<int64_t> source,
38+
ArrayRef<int64_t> target) {
39+
if (source.size() != target.size())
40+
return false;
41+
for (auto [srcDim, tgtDim] : llvm::zip_equal(source, target)) {
42+
// If the source dimension is dynamic, then the target dimension can be
43+
// dynamic or static.
44+
if (isDynamic(srcDim))
45+
continue;
46+
// Static source dim and dynamic result dim -> not a refinement.
47+
if (isDynamic(tgtDim))
48+
return false;
49+
// Static source dim != static result dim -> not a refinement.
50+
if (srcDim != tgtDim)
51+
return false;
52+
}
53+
return true;
54+
}
55+
56+
bool ShapedType::isShapeGeneralizationOf(ArrayRef<int64_t> source,
57+
ArrayRef<int64_t> target) {
58+
return isShapeRefinementOf(target, source);
59+
}

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,19 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
388388
%11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
389389
return %11 : tensor<?x?x?xf16>
390390
}
391+
392+
// -----
393+
394+
func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
395+
%0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
396+
%1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
397+
%2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
398+
return %2 : tensor<?xf32>
399+
}
400+
401+
// Check that a `tensor.cast` producer is not absorbed.
402+
403+
// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
404+
// CHECK: tensor.cast
405+
// CHECK: bufferization.materialize_in_destination
406+
// CHECK-SAME: : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

0 commit comments

Comments
 (0)