-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][memref] Add HoistCastPos pattern to castOp #168337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][memref] Add HoistCastPos pattern to castOp #168337
Conversation
|
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: lonely eagle (linuxlonelyeagle) ChangesThe foldUseDominateCast function is used to eliminate redundant casts. Full diff: https://github.com/llvm/llvm-project/pull/168337.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+ auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
+ if (!funcOp)
+ return {};
+ auto castOps = castOp->getBlock()->getOps<CastOp>();
+ CastOp dominateCastOp = castOp;
+ SmallVector<CastOp> ops(castOps);
+ mlir::DominanceInfo dominanceInfo(castOp);
+ for (auto it : castOps) {
+ if (it.getSource() == dominateCastOp.getSource() &&
+ it.getDest().getType() == dominateCastOp.getDest().getType() &&
+ dominanceInfo.dominates(it.getOperation(),
+ dominateCastOp.getOperation())) {
+ dominateCastOp = it;
+ }
+ }
+ return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
+}
+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
- return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+ OpFoldResult result;
+ if (OpFoldResult value = foldUseDominateCast(*this))
+ result = value;
+ if (succeeded(foldMemRefCast(*this)))
+ result = getResult();
+ return result;
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
+
+// -----
+
+func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
+ return
+}
+
+// CHECK-LABEL: func @fold_use_dominate_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+ %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+ call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+ // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+ call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+ return
+}
|
|
Following case don't work by use CSE pass. The following case is straightforward; I cannot provide you with the complete case because the IR is too complex.https://discourse.llvm.org/t/will-ops-without-side-effects-be-reordered-when-running-the-pass/85222,I went back through our previous chat history. I will reimplement it, and thank you for your review suggestions above. cc: @joker-eph |
|
This issue is resolved by hoisting the position of |
🐧 Linux x64 Test Results
|
| return success(); | ||
| } | ||
| return failure(); | ||
| } else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| } else { | |
| } |
Nit: no else-after-return
|
We could consider that "cast are canonicalized to be closest to their definition" or something like that, but:
|
You are right. This issue has evolved from the original problem into "How to use the CSE pass more efficiently?".
CSE work on it. We can implement a more generic pass based on SSA dominance to hoist the Pure op.To be perfectly honest, I'm not entirely sure how difficult it is.However, I'd be quite happy to make it happen. We do need to consider more people's suggestions regarding this issue.
For me, this implementation is already sufficient. "closest to their definition".It means they are in a block, so I can use CSE, right? |
The foldUseDominateCast function is used to eliminate redundant casts.