From f62bd0f649b55750c1d61b9732b99ad15f65092d Mon Sep 17 00:00:00 2001 From: yanming Date: Wed, 15 Oct 2025 14:10:24 +0800 Subject: [PATCH 1/4] [mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 24 ++++++++++++++++- mlir/test/Dialect/MemRef/canonicalize.mlir | 30 +++++++++++++++------- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..de797c4789480 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,33 @@ struct ReinterpretCastOpExtractStridedMetadataFolder return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(llvm::concat(op.getOffsets(), op.getSizes(), + op.getStrides()), + getConstantIntValue)) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), + op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + + rewriter.replaceOpWithNewOp(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } FailureOr>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c8bcb08..7160b52af6353 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref) -> memref> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref to memref> + return %reinterpret_cast : memref> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref> @@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> From efc5e8264776ca5685e562517befce3f25c0e8fe Mon Sep 17 00:00:00 2001 From: yanming Date: Fri, 17 Oct 2025 17:34:52 +0800 Subject: [PATCH 2/4] Ensure that success() is returned only if the IR has been modified. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index de797c4789480..dbfe9988533ce 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2166,14 +2166,22 @@ struct ReinterpretCastOpConstantFolder LogicalResult matchAndRewrite(ReinterpretCastOp op, PatternRewriter &rewriter) const override { - if (!llvm::any_of(llvm::concat(op.getOffsets(), op.getSizes(), - op.getStrides()), - getConstantIntValue)) + unsigned srcStaticCount = llvm::count_if( + llvm::concat(op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { return isa(ofr); }); + + SmallVector offsets = {op.getConstifiedMixedOffset()}; + SmallVector sizes = op.getConstifiedMixedSizes(); + SmallVector strides = op.getConstifiedMixedStrides(); + + if (srcStaticCount == + llvm::count_if(llvm::concat(offsets, sizes, strides), + [](OpFoldResult ofr) { return isa(ofr); })) return failure(); auto newReinterpretCast = ReinterpretCastOp::create( - rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), - op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); rewriter.replaceOpWithNewOp(op, op.getType(), newReinterpretCast); return success(); From 2f5f8207730ff05cb392e6900100a0a3ae64df47 Mon Sep 17 00:00:00 2001 From: yanming Date: Fri, 17 Oct 2025 18:06:37 +0800 Subject: [PATCH 3/4] Add comments --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index dbfe9988533ce..36b995354c38c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2166,6 +2166,10 @@ struct ReinterpretCastOpConstantFolder LogicalResult matchAndRewrite(ReinterpretCastOp op, PatternRewriter &rewriter) const override { + // TODO: Using counting comparison instead of direct comparison because + // getMixedValues (and consequently ReinterpretCastOp::getMixed...) returns + // IntegerAttrs, while constifyIndexValues (and consequently + // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. unsigned srcStaticCount = llvm::count_if( llvm::concat(op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()), From 3a83ed042f19f8f450894c897974fb8314ea5ae8 Mon Sep 17 00:00:00 2001 From: yanming Date: Fri, 17 Oct 2025 18:09:58 +0800 Subject: [PATCH 4/4] Move comments --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 36b995354c38c..4383761f07dae 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2166,10 +2166,6 @@ struct ReinterpretCastOpConstantFolder LogicalResult matchAndRewrite(ReinterpretCastOp op, PatternRewriter &rewriter) const override { - // TODO: Using counting comparison instead of direct comparison because - // getMixedValues (and consequently ReinterpretCastOp::getMixed...) returns - // IntegerAttrs, while constifyIndexValues (and consequently - // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. unsigned srcStaticCount = llvm::count_if( llvm::concat(op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()), @@ -2179,6 +2175,10 @@ struct ReinterpretCastOpConstantFolder SmallVector sizes = op.getConstifiedMixedSizes(); SmallVector strides = op.getConstifiedMixedStrides(); + // TODO: Using counting comparison instead of direct comparison because + // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns + // IntegerAttrs, while constifyIndexValues (and therefore + // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. if (srcStaticCount == llvm::count_if(llvm::concat(offsets, sizes, strides), [](OpFoldResult ofr) { return isa(ofr); }))