[mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. #163505
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref
Author: Ming Yan (NexMing)
Changes: Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes, and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
Full diff: https://github.com/llvm/llvm-project/pull/163505.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+ op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) {
+ return isa<Value>(ofr) && getConstantIntValue(ofr);
+ }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+ op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
// -----
+// CHECK-LABEL: func @reinterpret_constant_fold
+// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
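In effect, the new pattern splits a constant-operand reinterpret_cast into a fully static reinterpret_cast plus a memref.cast back to the original type. A consolidated before/after sketch of the @reinterpret_constant_fold test above (the exact printed result type is inferred by the pattern and is an assumption here):

// Before: every operand is a constant, but the result type is fully dynamic.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c100 = arith.constant 100 : index
%0 = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

// After canonicalization: a static reinterpret_cast, then a cast back to
// the original dynamic type so existing uses still type-check.
%1 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [100, 100], strides: [100, 1] : memref<f32> to memref<100x100xf32, strided<[100, 1]>>
%2 = memref.cast %1 : memref<100x100xf32, strided<[100, 1]>> to memref<?x?xf32, strided<[?, ?], offset: ?>>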
…/strides are constants.
Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes, and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
(Branch force-pushed from 92416e8 to f62bd0f.)
This test does not trigger the verifier before the canonicalization pattern runs, but it does afterwards. Should the verifier be stricter, or should the pattern fail on such an op?
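The test in question is not reproduced above; as an assumed minimal repro, the problematic shape is an op whose size operand is a statically negative constant. The op verifies as written, because the verifier does not look through the SSA value:

// Hypothetical repro (assumed, not copied from the PR).
func.func @negative_size(%arg0: memref<f32>) -> memref<?xf32, strided<[1]>> {
  %c-1 = arith.constant -1 : index
  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%c-1], strides: [1] : memref<f32> to memref<?xf32, strided<[1]>>
  return %0 : memref<?xf32, strided<[1]>>
}

// After the new pattern constifies the size, the generated op carries the
// negative extent statically (sizes: [-1] and a negative dimension in the
// result type), which fails verification.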
The size of a memref is not permitted to be negative. Statically negative memref sizes have every right to be an error.
That is, option 3: the test is UB.
OK, so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as …
The canonicalization pattern from this PR must be updated: if the pattern would generate an op that does not verify, it must abort.
I'd argue that the op was already invalid before canonicalization, just not verifiably invalid. The size of a dimension, even a dynamic one, should always be nonnegative. So that %c-1 should be a %runtime.var.with.size or a …
This is verifiable. The value comes from a constant so it can be checked in the verifier.
A dynamic size is encoded as a negative value, so in my opinion the verifier should be updated to check this, or the canonicalization pattern should bail out in case the new op does not verify.
Why not directly use the …
What a coincidence: I am also working on the FIR to standard MLIR conversion.
This value may later be used in code to detect assumed-size (e.g., SELECT RANK). This specific value is also mandated in C-Fortran interoperability contexts (Fortran 2023 section 18.5.3). So in the FIR "equivalent" of memref, … I think the issue here is that memref is probably not currently designed to support this Fortran use case.
What about this IR?
That's valid, but not checkable at compile time.
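The IR being discussed is not captured above; as a hypothetical sketch, the valid-but-uncheckable variant is one where the negative size only arrives at runtime:

// Hypothetical sketch (assumed): %n is opaque to the verifier, so the op
// is well-formed even if %n turns out to be negative at runtime.
func.func @runtime_size(%arg0: memref<f32>, %n: index) -> memref<?xf32, strided<[1]>> {
  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%n], strides: [1] : memref<f32> to memref<?xf32, strided<[1]>>
  return %0 : memref<?xf32, strided<[1]>>
}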
Yeah, memref doesn't allow negative sizes, and there's a decent amount of code that, for example, assumes that you can take the product of all the sizes. What I'd use is some value like …
Thanks, that is an interesting solution that would definitely solve the FIR to MemRef issue at hand!
Right, and using … This means that Flang just cannot rely on the MLIR memref dialect and related optimization passes to optimize code containing assumed-size manipulations (which are rather common in old, and not so old, Fortran code).
Side note: I think this kind of thing, where illegal constant inputs reach operations after constant folding/inlining, in code that cannot easily be proved to be reachable, is a bit of a hassle. It would be nice if there were an option for the verification pass to replace illegal operations with traps (or whatever the driver wants; personally I am not a fan of unreachable, since it makes debugging bad user code very hard because of code pruning). I guess some verifier errors should always be fatal, but bad inputs that do not have to be compile-time constants should not be, IMHO. We are not hitting the issue too much with flang, so I have not spent time trying to come up with something, but why not at some point, if there is interest.