Skip to content

Conversation

NexMing
Copy link
Contributor

@NexMing NexMing commented Oct 15, 2025

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

Changes

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/163505.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     return success();
   }
 };
+
+struct ReinterpretCastOpConstantFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+                                                 op.getMixedSizes(),
+                                                 op.getMixedStrides()),
+                      [](OpFoldResult ofr) {
+                        return isa<Value>(ofr) && getConstantIntValue(ofr);
+                      }))
+      return failure();
+
+    auto newReinterpretCast = ReinterpretCastOp::create(
+        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+    return success();
+  }
+};
 } // namespace
 
 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
-  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+              ReinterpretCastOpConstantFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_constant_fold
+//  CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-memref

Author: Ming Yan (NexMing)

Changes

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/163505.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     return success();
   }
 };
+
+struct ReinterpretCastOpConstantFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+                                                 op.getMixedSizes(),
+                                                 op.getMixedStrides()),
+                      [](OpFoldResult ofr) {
+                        return isa<Value>(ofr) && getConstantIntValue(ofr);
+                      }))
+      return failure();
+
+    auto newReinterpretCast = ReinterpretCastOp::create(
+        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+    return success();
+  }
+};
 } // namespace
 
 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
-  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+              ReinterpretCastOpConstantFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_constant_fold
+//  CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

…/strides are constants.

Implement folding logic to canonicalize memref.reinterpret_cast ops when
offset, sizes and strides are compile-time constants. This removes dynamic
shape annotations and produces a static memref form, allowing further
lowering and backend optimizations.
@NexMing NexMing force-pushed the dev/reinterpret-constant-fold branch from 92416e8 to f62bd0f Compare October 15, 2025 06:43
@NexMing NexMing enabled auto-merge (squash) October 17, 2025 10:11
@NexMing NexMing merged commit c988bf8 into llvm:main Oct 17, 2025
10 checks passed
@NexMing NexMing deleted the dev/reinterpret-constant-fold branch October 17, 2025 10:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants