From b025178c6406849b16c753fe0fb9f3e9920606ce Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Mon, 17 Nov 2025 09:52:34 +0000 Subject: [PATCH 1/3] add foldUseDominateCast to castOp. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 28 +++++++++++++++++++++- mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f270da6..aafd908c7af7e 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -13,10 +13,12 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" @@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { return false; } +static OpFoldResult foldUseDominateCast(CastOp castOp) { + auto funcOp = castOp->getParentOfType(); + if (!funcOp) + return {}; + auto castOps = castOp->getBlock()->getOps(); + CastOp dominateCastOp = castOp; + SmallVector ops(castOps); + mlir::DominanceInfo dominanceInfo(castOp); + for (auto it : castOps) { + if (it.getSource() == dominateCastOp.getSource() && + it.getDest().getType() == dominateCastOp.getDest().getType() && + dominanceInfo.dominates(it.getOperation(), + dominateCastOp.getOperation())) { + dominateCastOp = it; + } + } + return dominateCastOp == castOp ? Value() : dominateCastOp.getResult(); +} + OpFoldResult CastOp::fold(FoldAdaptor adaptor) { - return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); + OpFoldResult result; + if (OpFoldResult value = foldUseDominateCast(*this)) + result = value; + if (succeeded(foldMemRefCast(*this))) + result = getResult(); + return result; } FailureOr>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 313090272ef90..3638b8d4ac701 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref, %arg0 : index) %res = memref.view %0[%c0][%arg0] : memref to memref return %res : memref } + +// ----- + +func.func @fold_use_dominate_cast_foo(%arg0: memref>) { + return +} + +// CHECK-LABEL: func @fold_use_dominate_cast( +// CHECK-SAME: %[[ARG0:.*]]: memref) +func.func @fold_use_dominate_cast(%arg: memref) { + // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] + %cast0 = memref.cast %arg : memref to memref> + %cast1 = memref.cast %arg : memref to memref> + // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]]) + call @fold_use_dominate_cast_foo(%cast0) : (memref>) -> () + // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]]) + call @fold_use_dominate_cast_foo(%cast1) : (memref>) -> () + return +} From f3127f08d48266113c1a4f84cbe03c93944af194 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Tue, 18 Nov 2025 06:42:14 +0000 Subject: [PATCH 2/3] add HoistCastPos pattern. --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 + mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 51 ++++++++++-------- mlir/test/Dialect/MemRef/canonicalize.mlir | 52 +++++++++++++------ mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 4 +- 4 files changed, 68 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 0bf22928f6900..c342f25fe61a9 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -565,6 +565,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [ }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index aafd908c7af7e..b489f71b775e0 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -795,32 +795,37 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { return false; } -static OpFoldResult foldUseDominateCast(CastOp castOp) { - auto funcOp = castOp->getParentOfType(); - if (!funcOp) - return {}; - auto castOps = castOp->getBlock()->getOps(); - CastOp dominateCastOp = castOp; - SmallVector ops(castOps); - mlir::DominanceInfo dominanceInfo(castOp); - for (auto it : castOps) { - if (it.getSource() == dominateCastOp.getSource() && - it.getDest().getType() == dominateCastOp.getDest().getType() && - dominanceInfo.dominates(it.getOperation(), - dominateCastOp.getOperation())) { - dominateCastOp = it; +OpFoldResult CastOp::fold(FoldAdaptor adaptor) { + return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); +} + +namespace { +struct HoistCastPos : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CastOp castOp, + PatternRewriter &rewriter) const override { + if (auto *defineOp = castOp.getSource().getDefiningOp()) { + if (defineOp->getBlock() != castOp->getBlock()) { + rewriter.moveOpAfter(castOp.getOperation(), defineOp); + return success(); + } + return failure(); + } else { + auto argument = cast(castOp.getSource()); + if (argument.getOwner() != castOp->getBlock()) { + rewriter.moveOpBefore(castOp.getOperation(), + &argument.getOwner()->front()); + return success(); + } + return failure(); } } - return dominateCastOp == castOp ? Value() : dominateCastOp.getResult(); -} +}; +} // namespace -OpFoldResult CastOp::fold(FoldAdaptor adaptor) { - OpFoldResult result; - if (OpFoldResult value = foldUseDominateCast(*this)) - result = value; - if (succeeded(foldMemRefCast(*this))) - result = getResult(); - return result; +void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); } FailureOr>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 3638b8d4ac701..e435615cc8e26 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1370,19 +1370,41 @@ func.func @non_fold_view_same_source_res_types(%0: memref, %arg0 : index) // ----- -func.func @fold_use_dominate_cast_foo(%arg0: memref>) { - return -} - -// CHECK-LABEL: func @fold_use_dominate_cast( -// CHECK-SAME: %[[ARG0:.*]]: memref) -func.func @fold_use_dominate_cast(%arg: memref) { - // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] - %cast0 = memref.cast %arg : memref to memref> - %cast1 = memref.cast %arg : memref to memref> - // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]]) - call @fold_use_dominate_cast_foo(%cast0) : (memref>) -> () - // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]]) - call @fold_use_dominate_cast_foo(%cast1) : (memref>) -> () - return +// CHECK-LABEL: func @hoist_cast_pos +// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>, +// CHECK-SAME: %[[ARG1:.*]]: i1 +func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref) { + // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] + // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]] + // CHECK-NEXT: cf.cond_br %[[ARG1]] + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1: + %cast = memref.cast %arg : memref<10xf32> to memref + // CHECK: return %[[CAST_1]] + return %cast : memref +^bb2: + %cast1 = memref.cast %arg : memref<10xf32> to memref + // CHECK: return %[[CAST_0]] + return %cast1 : memref +} + +// ----- + +// CHECK-LABEL: func.func @hoist_cast_pos_alloc +// CHECK-SAME: %[[ARG0:.*]]: i1 +func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref) { + // CHECK: %[[ALLOC_0:.*]] = memref.alloc() + // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]] + // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]] + // CHECK-NEXT: cf.cond_br %[[ARG0]] + %alloc = memref.alloc() : memref<10xf32> + cf.cond_br %arg, ^bb1, ^bb2 +^bb1: + %cast = memref.cast %alloc : memref<10xf32> to memref + // CHECK: return %[[CAST_1]] + return %cast : memref +^bb2: + %cast1 = memref.cast %alloc : memref<10xf32> to memref + // CHECK: return %[[CAST_0]] + return %cast1 : memref } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir index af09dc865e2de..1ae6e3a8a3cf7 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -922,13 +922,13 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f: // CHECK-SAME: %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}> func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> { // Throw in a tensor that bufferizes to a different layout map. - // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32> + // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32> + // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>> %a = bufferization.alloc_tensor() : tensor<5xf32> // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>> %0 = scf.index_switch %pred -> tensor<5xf32> // CHECK: case 2 { - // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>> // CHECK: scf.yield %[[cast]] case 2 { scf.yield %a: tensor<5xf32> From 6f0497c201b6d43773a8138b7b38e69ec8ef4b34 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Tue, 18 Nov 2025 06:45:46 +0000 Subject: [PATCH 3/3] cleanup code. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 -- mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index b489f71b775e0..e94db0ccb11de 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -13,12 +13,10 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir index 1ae6e3a8a3cf7..d1c1f1780e353 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -923,7 +923,7 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f: func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> { // Throw in a tensor that bufferizes to a different layout map. // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32> - // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>> + // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>> %a = bufferization.alloc_tensor() : tensor<5xf32> // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>