Skip to content

Conversation

@matthias-springer
Copy link
Member

Implement runtime op verification for memref.copy. Only ranked memrefs are verified at the moment.

@llvmbot
Copy link
Member

llvmbot commented Mar 8, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Implement runtime op verification for memref.copy. Only ranked memrefs are verified at the moment.


Full diff: https://github.com/llvm/llvm-project/pull/130437.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+48)
  • (added) mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir (+28)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index ceea27a35a225..c604af249ba2e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -182,6 +182,53 @@ struct CastOpInterface
   }
 };
 
+struct CopyOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+                                                         CopyOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto copyOp = cast<CopyOp>(op);
+    BaseMemRefType sourceType = copyOp.getSource().getType();
+    BaseMemRefType targetType = copyOp.getTarget().getType();
+    auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
+    auto rankedTargetType = dyn_cast<MemRefType>(targetType);
+
+    // TODO: Verification for unranked memrefs is not supported yet.
+    if (!rankedSourceType || !rankedTargetType)
+      return;
+
+    assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
+    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+      // Fully static dimensions in both source and target operand are already
+      // verified by the op verifier.
+      if (!rankedSourceType.isDynamicDim(i) &&
+          !rankedTargetType.isDynamicDim(i))
+        continue;
+      Value sourceDim;
+      if (rankedSourceType.isDynamicDim(i)) {
+        sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i);
+      } else {
+        sourceDim = builder.create<arith::ConstantIndexOp>(
+            loc, rankedSourceType.getDimSize(i));
+      }
+      Value targetDim;
+      if (rankedTargetType.isDynamicDim(i)) {
+        targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i);
+      } else {
+        targetDim = builder.create<arith::ConstantIndexOp>(
+            loc, rankedTargetType.getDimSize(i));
+      }
+      Value sameDimSize = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+      builder.create<cf::AssertOp>(
+          loc, sameDimSize,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "size of " + std::to_string(i) +
+                      "-th source/target dim does not match"));
+    }
+  }
+};
+
 struct DimOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                          DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
     AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
+    CopyOp::attachInterface<CopyOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     GenericAtomicRMWOp::attachInterface<
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Put memref.copy in a function, otherwise the memref.cast may fold.
+func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
+  memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
+  return
+}
+
+func.func @main() {
+  %alloca1 = memref.alloca() : memref<4xf32>
+  %alloca2 = memref.alloca() : memref<5xf32>
+  %cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
+  %cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+  // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+  // CHECK-NEXT: Location: loc({{.*}})
+  call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
+
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 8, 2025

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

Changes

Implement runtime op verification for memref.copy. Only ranked memrefs are verified at the moment.


Full diff: https://github.com/llvm/llvm-project/pull/130437.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+48)
  • (added) mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir (+28)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index ceea27a35a225..c604af249ba2e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -182,6 +182,53 @@ struct CastOpInterface
   }
 };
 
+struct CopyOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+                                                         CopyOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto copyOp = cast<CopyOp>(op);
+    BaseMemRefType sourceType = copyOp.getSource().getType();
+    BaseMemRefType targetType = copyOp.getTarget().getType();
+    auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
+    auto rankedTargetType = dyn_cast<MemRefType>(targetType);
+
+    // TODO: Verification for unranked memrefs is not supported yet.
+    if (!rankedSourceType || !rankedTargetType)
+      return;
+
+    assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
+    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+      // Fully static dimensions in both source and target operand are already
+      // verified by the op verifier.
+      if (!rankedSourceType.isDynamicDim(i) &&
+          !rankedTargetType.isDynamicDim(i))
+        continue;
+      Value sourceDim;
+      if (rankedSourceType.isDynamicDim(i)) {
+        sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i);
+      } else {
+        sourceDim = builder.create<arith::ConstantIndexOp>(
+            loc, rankedSourceType.getDimSize(i));
+      }
+      Value targetDim;
+      if (rankedTargetType.isDynamicDim(i)) {
+        targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i);
+      } else {
+        targetDim = builder.create<arith::ConstantIndexOp>(
+            loc, rankedTargetType.getDimSize(i));
+      }
+      Value sameDimSize = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+      builder.create<cf::AssertOp>(
+          loc, sameDimSize,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "size of " + std::to_string(i) +
+                      "-th source/target dim does not match"));
+    }
+  }
+};
+
 struct DimOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                          DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
     AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
+    CopyOp::attachInterface<CopyOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     GenericAtomicRMWOp::attachInterface<
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Put memref.copy in a function, otherwise the memref.cast may fold.
+func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
+  memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
+  return
+}
+
+func.func @main() {
+  %alloca1 = memref.alloca() : memref<4xf32>
+  %alloca2 = memref.alloca() : memref<5xf32>
+  %cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
+  %cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+  // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+  // CHECK-NEXT: Location: loc({{.*}})
+  call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
+
+  return
+}

@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_copy_verification branch from 7bb852c to 2a3b442 Compare March 11, 2025 10:57
@matthias-springer matthias-springer changed the base branch from users/matthias-springer/atomic_rmw_verification to main March 11, 2025 10:57
@matthias-springer matthias-springer merged commit 1b455df into main Mar 11, 2025
18 of 22 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/memref_copy_verification branch March 11, 2025 12:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants