Conversation

@linuxlonelyeagle
Member

The foldUseDominateCast function eliminates redundant casts: a memref.cast is folded to the result of an earlier, dominating cast with the same source and result type.

@llvmbot
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

Changes

The foldUseDominateCast function eliminates redundant casts: a memref.cast is folded to the result of an earlier, dominating cast with the same source and result type.


Full diff: https://github.com/llvm/llvm-project/pull/168337.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+27-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+19)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return false;
 }
 
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+  auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
+  if (!funcOp)
+    return {};
+  auto castOps = castOp->getBlock()->getOps<CastOp>();
+  CastOp dominateCastOp = castOp;
+  SmallVector<CastOp> ops(castOps);
+  mlir::DominanceInfo dominanceInfo(castOp);
+  for (auto it : castOps) {
+    if (it.getSource() == dominateCastOp.getSource() &&
+        it.getDest().getType() == dominateCastOp.getDest().getType() &&
+        dominanceInfo.dominates(it.getOperation(),
+                                dominateCastOp.getOperation())) {
+      dominateCastOp = it;
+    }
+  }
+  return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
+}
+
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
-  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+  OpFoldResult result;
+  if (OpFoldResult value = foldUseDominateCast(*this))
+    result = value;
+  if (succeeded(foldMemRefCast(*this)))
+    result = getResult();
+  return result;
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
+  return
+}
+
+// CHECK-LABEL: func @fold_use_dominate_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> () 
+  return
+}

@linuxlonelyeagle
Member Author

The following case is not handled by the CSE pass. It is a simplified example; I cannot provide the complete case because the original IR is too complex. See https://discourse.llvm.org/t/will-ops-without-side-effects-be-reordered-when-running-the-pass/85222, where I went back through our previous chat history. I will reimplement this, and thank you for your review suggestions above. cc: @joker-eph

func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
  return
}

func.func @fold_use_dominate_cast(%arg: memref<?xf32>, %arg1: index) -> index {
  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
  affine.for %i = 0 to 10 {
    %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
    func.call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
  }
  func.call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> () 
  %c1 = arith.constant 1 : index
  %ret = arith.addi %arg1, %c1 : index
  return %ret : index
}

@linuxlonelyeagle linuxlonelyeagle changed the title [mlir][memref] Add foldUseDominateCast function to castOp [mlir][memref] Add HoistCastPos pattern to castOp Nov 18, 2025
@linuxlonelyeagle
Member Author

This issue is resolved by hoisting the castOp toward the definition of its source and then running CSE.
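
For illustration, here is a minimal sketch of what such a hoisting pattern could look like. This is a hypothetical reconstruction, not the PR's actual code: the name HoistCastPos comes from the PR title, but the exact matching and placement policy are assumptions.

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

/// Hypothetical sketch: move a memref.cast into the block where its source
/// becomes available, so that identical casts scattered across blocks end up
/// in the same block and a subsequent CSE run can merge them.
struct HoistCastPos : public OpRewritePattern<memref::CastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    Value source = castOp.getSource();
    // Find the block in which the source becomes available.
    Operation *def = source.getDefiningOp();
    Block *targetBlock =
        def ? def->getBlock() : cast<BlockArgument>(source).getOwner();
    // Only hoist casts that are not already in that block; once all
    // duplicates share a block, CSE takes care of the rest. This also keeps
    // the greedy rewriter from re-applying the pattern forever.
    if (castOp->getBlock() == targetBlock)
      return failure();
    // The cast is pure, and the new position dominates the old one, so all
    // existing users of the result remain dominated.
    if (def)
      rewriter.moveOpAfter(castOp, def);
    else
      rewriter.moveOpBefore(castOp, targetBlock, targetBlock->begin());
    return success();
  }
};
```

With the duplicates in one block, a later -cse run merges them; this only firing across blocks is consistent with the "only if in a different block" observation in the review below.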

@github-actions

🐧 Linux x64 Test Results

  • 7081 tests passed
  • 594 tests skipped

return success();
}
return failure();
} else {
Collaborator


Suggested change:
- } else {
+ }

Nit: no else-after-return
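
For reference, the LLVM coding standards discourage an else after a return; a generic illustration (not the PR's code):

```cpp
// Discouraged: the `else` is redundant because the `if` branch returns.
int withElse(bool cond) {
  if (cond) {
    return 1;
  } else {
    return 0;
  }
}

// Preferred: drop the `else` and unindent the fall-through path.
int withoutElse(bool cond) {
  if (cond)
    return 1;
  return 0;
}
```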

@joker-eph
Collaborator

We could consider a rule like "casts are canonicalized to be closest to their definition", but:

  1. I'm not entirely sure about this one: what makes casts "special" enough to have this property? Why not other ops? This likely requires more opinions here.
  2. You're only achieving this partially (only if in a different block).

@linuxlonelyeagle
Member Author

We could consider a rule like "casts are canonicalized to be closest to their definition", but:

  1. I'm not entirely sure about this one: what makes casts "special" enough to have this property? Why not other ops? This likely requires more opinions here.

You are right. This issue has evolved from the original problem into "How do we use the CSE pass more effectively?"; see https://discourse.llvm.org/t/will-ops-without-side-effects-be-reordered-when-running-the-pass/85222/.
As we discussed earlier, if an op is pure, we have the opportunity to hoist it.

  • How do we use the CSE pass more effectively?
    CSE does not work on the following code:
func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
  cf.cond_br %arg1, ^bb1, ^bb2
^bb1:
  %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
  return %cast : memref<?xf32>
^bb2:
  %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
  return %cast1 : memref<?xf32> 
}

After the casts are hoisted into the entry block, CSE works on it:

func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
   %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>  
   %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
  cf.cond_br %arg1, ^bb1, ^bb2
^bb1:
  return %cast : memref<?xf32>
^bb2:
  return %cast1 : memref<?xf32> 
}

//  run mlir-opt hoist_cast_pos.mlir -cse
func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
   %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>  
  cf.cond_br %arg1, ^bb1, ^bb2
^bb1:
  return %cast : memref<?xf32>
^bb2:
  return %cast : memref<?xf32> 
}

We could implement a more generic pass, based on SSA dominance, to hoist pure ops. To be perfectly honest, I'm not entirely sure how difficult that would be; however, I'd be quite happy to make it happen. We do need to consider more people's suggestions regarding this issue. A rough sketch of what such a pass might look like follows.
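
To make the idea concrete, here is a rough, hypothetical sketch of such a dominance-driven pass (the function name, the placement policy, and the handling of block arguments are all assumptions, not an implementation proposal):

```cpp
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

/// Hypothetical: hoist every region-free pure op to just after the
/// definition of its "latest" operand, so that duplicates from different
/// blocks become co-located and CSE can merge them.
static void hoistPureOps(FunctionOpInterface funcOp) {
  if (funcOp.isExternal())
    return;
  DominanceInfo domInfo(funcOp.getOperation());
  Block *entryBlock = &funcOp.getFunctionBody().front();
  funcOp->walk([&](Operation *op) {
    // Only region-free pure ops are trivially safe to move.
    if (op->getNumRegions() != 0 || !isPure(op))
      return;
    // Every operand definition dominates `op`, so the definitions are
    // totally ordered by dominance; find the latest one.
    Operation *insertAfter = nullptr;
    for (Value operand : op->getOperands()) {
      if (Operation *def = operand.getDefiningOp()) {
        if (!insertAfter || domInfo.properlyDominates(insertAfter, def))
          insertAfter = def;
        continue;
      }
      // Entry-block arguments are visible everywhere in the function;
      // other block arguments would need more care, so skip those ops.
      if (cast<BlockArgument>(operand).getOwner() != entryBlock)
        return;
    }
    // We only ever move ops earlier: the new position dominates the old
    // one, so all existing users of the results remain dominated.
    if (insertAfter)
      op->moveAfter(insertAfter);
    else
      op->moveBefore(entryBlock, entryBlock->begin());
  });
}
```

Note that the open question from the review, why casts specifically and not all pure ops, applies to this sketch as well; it hoists everything pure, which is exactly the design point that needs wider input.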

  2. You're only achieving this partially (only if in a different block).

For me, this implementation is already sufficient. "Closest to their definition" means the casts end up in the same block, so I can then use CSE, right?
