diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..c342f25fe61a9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -565,6 +565,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..e94db0ccb11de 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -797,6 +797,35 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+namespace {
+struct HoistCastPos : public OpRewritePattern<CastOp> {
+  using OpRewritePattern<CastOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto *defineOp = castOp.getSource().getDefiningOp()) {
+      if (defineOp->getBlock() != castOp->getBlock()) {
+        rewriter.moveOpAfter(castOp.getOperation(), defineOp);
+        return success();
+      }
+      return failure();
+    } else {
+      auto argument = cast<BlockArgument>(castOp.getSource());
+      if (argument.getOwner() != castOp->getBlock()) {
+        rewriter.moveOpBefore(castOp.getOperation(),
+                              &argument.getOwner()->front());
+        return success();
+      }
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<HoistCastPos>(context);
+}
+
 FailureOr<std::optional<SmallVector<Value>>>
 CastOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..e435615cc8e26 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,44 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @hoist_cast_pos
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG1]]
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+// CHECK-SAME: %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+  // CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+  // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG0]]
+  %alloc = memref.alloc() : memref<10xf32>
+  cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32>
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index af09dc865e2de..d1c1f1780e353 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -922,13 +922,13 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
 // CHECK-SAME: %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
 func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
   // Throw in a tensor that bufferizes to a different layout map.
-  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
   %a = bufferization.alloc_tensor() : tensor<5xf32>
   // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
   %0 = scf.index_switch %pred -> tensor<5xf32>
   // CHECK: case 2 {
-  // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
   // CHECK: scf.yield %[[cast]]
   case 2 {
     scf.yield %a: tensor<5xf32>