-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. #163505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Ming Yan (NexMing) ChangesImplement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations. Full diff: https://github.com/llvm/llvm-project/pull/163505.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+ op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) {
+ return isa<Value>(ofr) && getConstantIntValue(ofr);
+ }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+ op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
// -----
+// CHECK-LABEL: func @reinterpret_constant_fold
+// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
@llvm/pr-subscribers-mlir-memref Author: Ming Yan (NexMing) ChangesImplement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations. Full diff: https://github.com/llvm/llvm-project/pull/163505.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+ op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) {
+ return isa<Value>(ofr) && getConstantIntValue(ofr);
+ }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+ op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
// -----
+// CHECK-LABEL: func @reinterpret_constant_fold
+// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
…/strides are constants. Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
92416e8
to
f62bd0f
Compare
Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.