Conversation

NexMing
Contributor

@NexMing NexMing commented Oct 15, 2025

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
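
For illustration, a minimal before/after sketch of the rewrite (it mirrors the new test in the diff below; value names and the folded result type are illustrative):

// Before: offset, sizes, and strides are all arith.constant values.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c100 = arith.constant 100 : index
%0 = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1]
    : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

// After canonicalization: the constants become static attributes, and a
// memref.cast restores the original dynamic result type.
%1 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [100, 100], strides: [100, 1]
    : memref<f32> to memref<100x100xf32, strided<[100, 1]>>
%2 = memref.cast %1 : memref<100x100xf32, strided<[100, 1]>> to memref<?x?xf32, strided<[?, ?], offset: ?>>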

@llvmbot
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

Changes

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/163505.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     return success();
   }
 };
+
+struct ReinterpretCastOpConstantFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+                                                 op.getMixedSizes(),
+                                                 op.getMixedStrides()),
+                      [](OpFoldResult ofr) {
+                        return isa<Value>(ofr) && getConstantIntValue(ofr);
+                      }))
+      return failure();
+
+    auto newReinterpretCast = ReinterpretCastOp::create(
+        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+    return success();
+  }
+};
 } // namespace
 
 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
-  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+              ReinterpretCastOpConstantFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_constant_fold
+//  CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

@llvmbot
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-memref

Author: Ming Yan (NexMing)

…/strides are constants.

Implement folding logic to canonicalize memref.reinterpret_cast ops when
offset, sizes and strides are compile-time constants. This removes dynamic
shape annotations and produces a static memref form, allowing further
lowering and backend optimizations.
@NexMing NexMing force-pushed the dev/reinterpret-constant-fold branch from 92416e8 to f62bd0f on October 15, 2025 06:43
@NexMing NexMing enabled auto-merge (squash) October 17, 2025 10:11
@NexMing NexMing merged commit c988bf8 into llvm:main Oct 17, 2025
10 checks passed
@NexMing NexMing deleted the dev/reinterpret-constant-fold branch October 17, 2025 10:19
@clementval
Contributor

This test does not trigger the verifier before the canonicalization pattern runs, but it does afterwards. Should the verifier be stricter, or should the pattern fail on such an op?

func.func @reinterpret_constant_fold2(%arg0: memref<?x?x?xi32>, %arg1 : index) -> memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c-1 = arith.constant -1 : index
  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%arg1, %arg1, %c-1], strides: [%arg1, %arg1, %arg1] : memref<?x?x?xi32> to memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>>
  return %reinterpret_cast : memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>>
}

@krzysz00
Contributor

The size of a memref is not permitted to be negative

Statically negative memref sizes have every right to be an error

@krzysz00
Contributor

That is - option 3, the test is UB

@clementval
Contributor

The size of a memref is not permitted to be negative

Statically negative memref sizes have every right to be an error

Ok, so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as -1 whereas MemRef represents it as std::numeric_limits<int64_t>::min(). So I guess we need to align our representation.

@matthias-springer
Member

The canonicalization pattern from this PR must be updated: if the pattern would generate an op that does not verify, it must abort.

@krzysz00
Contributor

I'd argue that the op was already invalid before canonicalization - just not verifiably invalid?

The size of a dimension - even a dynamic one - should always be nonnegative?

So that %c-1 should be a %runtime.var.with.size or a my.truly.arbitrary.size.marker

@clementval
Contributor

clementval commented Oct 19, 2025

This is verifiable. The value comes from a constant so it can be checked in the verifier.

The size of a dimension - even a dynamic one - should always be nonnegative?

A dynamic size is represented in ShapedType by the negative sentinel value std::numeric_limits<int64_t>::min().

So in my opinion, the verifier should be updated to check this, or the canonicalization pattern should bail out when the new op does not verify.

@NexMing
Contributor Author

NexMing commented Oct 20, 2025

Ok, so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as -1 whereas MemRef represents it as std::numeric_limits<int64_t>::min(). So I guess we need to align our representation.

Why not directly use the ShapedType::kDynamic dynamic-size value? If the dynamic size cannot be determined even at runtime, I suggest using a ub.poison value to represent it.
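
A minimal sketch of that suggestion, assuming the ub dialect is available (the function and value names are illustrative, not from this PR):

func.func @assumed_size_view(%base: memref<f32>, %n: index) -> memref<?x?xf32, strided<[?, 1]>> {
  // The outermost extent is unknowable and never needed for addressing, so it
  // is modeled as a poison value instead of a sentinel such as -1.
  %unknown = ub.poison : index
  %view = memref.reinterpret_cast %base to offset: [0], sizes: [%unknown, %n], strides: [%n, 1]
      : memref<f32> to memref<?x?xf32, strided<[?, 1]>>
  return %view : memref<?x?xf32, strided<[?, 1]>>
}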

@NexMing
Contributor Author

NexMing commented Oct 20, 2025

Ok, so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as -1 whereas MemRef represents it as std::numeric_limits<int64_t>::min(). So I guess we need to align our representation.

What a coincidence — I am also working on the FIR to standard MLIR conversion.
I have already started experimenting, and preliminary results can be seen in my repository: https://github.com/NexMing/llvm-project/tree/dev/fir-to-mlir
I plan to push it to the main pipeline soon.

@jeanPerier
Contributor

jeanPerier commented Oct 20, 2025

Ok, so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as -1 whereas MemRef represents it as std::numeric_limits<int64_t>::min(). So I guess we need to align our representation.

Why not directly use the ShapedType::kDynamic dynamic-size value? If the dynamic size cannot be determined even at runtime, I suggest using a ub.poison value to represent it.

-1 is used in the Fortran descriptors to encode assumed-size arrays (ARRAY(n, m, *)). In these arrays, the extent of the outer dimension (the last one in Fortran) will never be known and does not matter (it is not needed to generate pointer arithmetic). There are a lot of restrictions on these arrays (the user can basically only index or pass them and should not do anything that would require the compiler to know the effective size of the array).

This value may later be used in code to detect assumed-size arrays (e.g. SELECT RANK). This specific value is also mandated in C-Fortran interoperability contexts (Fortran 2023 section 18.5.3).

So in the FIR "equivalent" of memref, fir.box, -1 is a well specified and expected value. We do not want to use poison. Another example of its importance is runtime bounds checking (which flang does not have yet). -1 will allow bounds checking to know that there is no way to check the index when addressing the last dimension of an assumed-size array.

I think the issue here is that memref is probably not designed currently to support this Fortran use case.
@NexMing, how are you translating fir.embox to memref to deal with assumed-size arrays?

@matthias-springer
Member

matthias-springer commented Oct 20, 2025

I'd argue that the op was already invalid before canonicalization - just not verifiably invalid?

What about this IR?

%sz = ...  // could be negative
scf.if (%sz >= 0) {
  %0 = memref.reinterpret_cast %m to offset: [%off], sizes: [%sz], strides: [%str] : memref<f32> to memref<?xf32, strided<[?], offset: ?>>
  ...
}

@krzysz00
Contributor

What about this IR?

%sz = ...  // could be negative
scf.if (%sz >= 0) {
  %0 = memref.reinterpret_cast %m to offset: [%off], sizes: [%sz], strides: [%str] : memref<f32> to memref<?xf32, strided<[?], offset: ?>>
  ...
}

That's valid, but not checkable at compile time.

I think the issue here is that memref is probably not designed currently to support this Fortran use case.

Yeah, memref doesn't allow negative sizes and there's a decent amount of code that - for example - assumes that you can take the product of all memref.dims and get a meaningful value at runtime.

What I'd use is some value like fir.unknown.size which isn't -1 but lowers to it when you're rewriting out of memref.
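
A hypothetical sketch of that idea (fir.unknown.size is not an existing op; its name, form, and lowering are assumptions here):

// Hypothetical marker op: its result stands for "the extent of an assumed-size
// dimension", stays opaque at the memref level, and would only be rewritten to
// -1 when leaving memref for the Fortran descriptor representation.
%star = fir.unknown.size : index
%view = memref.reinterpret_cast %base to offset: [0], sizes: [%star, %n], strides: [%n, 1]
    : memref<f32> to memref<?x?xf32, strided<[?, 1]>>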

@jeanPerier
Copy link
Contributor

What I'd use is some value like fir.unknown.size which isn't -1 but lowers to it when you're rewriting out of memref.

Thanks, that is an interesting solution that would definitely solve the FIR to MemRef issue at hand!

there's a decent amount of code that - for example - assumes that you can take the product of all memref.dims and get a meaningful value at runtime.

Right, and using fir.unknown.size would not solve the issue that any core MLIR pass working on memref may just create a temporary copy for whatever reason, assuming it can always do so, while this is just not possible with Fortran assumed-size arrays.
That is why I tend to agree with @matthias-springer that passes that rewrite fir.box to memref should only do it when the memref specification supports representing the related Fortran array.

This means that Flang just cannot rely on the MLIR memref dialect and related optimization passes to optimize code containing assumed-size manipulations (which are rather common in old and not-so-old Fortran code).

That's valid, but not checkable at compile time.

Side note: I think this kind of situation, where after constant folding/inlining an illegal constant input reaches an operation in code that cannot easily be proved reachable, is a bit of a hassle. It would be nice if there were an option for the verification pass to replace illegal operations with traps (or whatever the driver wants; personally I am not a fan of unreachable, it makes debugging bad user code very hard because of code pruning). I guess some verifier errors should always be fatal, but bad input that does not have to be a compile-time constant should not be, IMHO. We are not hitting the issue too much with flang, so I have not spent time trying to come up with something, but why not at some point if there is some interest.
