-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][memref] Add runtime verification for memref.copy
#130437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] Add runtime verification for memref.copy
#130437
Conversation
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesImplement runtime op verification for Full diff: https://github.com/llvm/llvm-project/pull/130437.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index ceea27a35a225..c604af249ba2e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -182,6 +182,53 @@ struct CastOpInterface
}
};
+struct CopyOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+ CopyOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto copyOp = cast<CopyOp>(op);
+ BaseMemRefType sourceType = copyOp.getSource().getType();
+ BaseMemRefType targetType = copyOp.getTarget().getType();
+ auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
+ auto rankedTargetType = dyn_cast<MemRefType>(targetType);
+
+ // TODO: Verification for unranked memrefs is not supported yet.
+ if (!rankedSourceType || !rankedTargetType)
+ return;
+
+ assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
+ for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ // Fully static dimensions in both source and target operand are already
+ // verified by the op verifier.
+ if (!rankedSourceType.isDynamicDim(i) &&
+ !rankedTargetType.isDynamicDim(i))
+ continue;
+ Value sourceDim;
+ if (rankedSourceType.isDynamicDim(i)) {
+ sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i);
+ } else {
+ sourceDim = builder.create<arith::ConstantIndexOp>(
+ loc, rankedSourceType.getDimSize(i));
+ }
+ Value targetDim;
+ if (rankedTargetType.isDynamicDim(i)) {
+ targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i);
+ } else {
+ targetDim = builder.create<arith::ConstantIndexOp>(
+ loc, rankedTargetType.getDimSize(i));
+ }
+ Value sameDimSize = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+ builder.create<cf::AssertOp>(
+ loc, sameDimSize,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "size of " + std::to_string(i) +
+ "-th source/target dim does not match"));
+ }
+ }
+};
+
struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
CastOp::attachInterface<CastOpInterface>(*ctx);
+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
GenericAtomicRMWOp::attachInterface<
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Put memref.copy in a function, otherwise the memref.cast may fold.
+func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
+ memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
+ return
+}
+
+func.func @main() {
+ %alloca1 = memref.alloca() : memref<4xf32>
+ %alloca2 = memref.alloca() : memref<5xf32>
+ %cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
+ %cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+ // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+ // CHECK-NEXT: Location: loc({{.*}})
+ call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
+
+ return
+}
|
|
@llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesImplement runtime op verification for Full diff: https://github.com/llvm/llvm-project/pull/130437.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index ceea27a35a225..c604af249ba2e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -182,6 +182,53 @@ struct CastOpInterface
}
};
+struct CopyOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
+ CopyOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto copyOp = cast<CopyOp>(op);
+ BaseMemRefType sourceType = copyOp.getSource().getType();
+ BaseMemRefType targetType = copyOp.getTarget().getType();
+ auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
+ auto rankedTargetType = dyn_cast<MemRefType>(targetType);
+
+ // TODO: Verification for unranked memrefs is not supported yet.
+ if (!rankedSourceType || !rankedTargetType)
+ return;
+
+ assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
+ for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ // Fully static dimensions in both source and target operand are already
+ // verified by the op verifier.
+ if (!rankedSourceType.isDynamicDim(i) &&
+ !rankedTargetType.isDynamicDim(i))
+ continue;
+ Value sourceDim;
+ if (rankedSourceType.isDynamicDim(i)) {
+ sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i);
+ } else {
+ sourceDim = builder.create<arith::ConstantIndexOp>(
+ loc, rankedSourceType.getDimSize(i));
+ }
+ Value targetDim;
+ if (rankedTargetType.isDynamicDim(i)) {
+ targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i);
+ } else {
+ targetDim = builder.create<arith::ConstantIndexOp>(
+ loc, rankedTargetType.getDimSize(i));
+ }
+ Value sameDimSize = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+ builder.create<cf::AssertOp>(
+ loc, sameDimSize,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "size of " + std::to_string(i) +
+ "-th source/target dim does not match"));
+ }
+ }
+};
+
struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
CastOp::attachInterface<CastOpInterface>(*ctx);
+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
GenericAtomicRMWOp::attachInterface<
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
new file mode 100644
index 0000000000000..95b9db2832cee
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// Put memref.copy in a function, otherwise the memref.cast may fold.
+func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
+ memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
+ return
+}
+
+func.func @main() {
+ %alloca1 = memref.alloca() : memref<4xf32>
+ %alloca2 = memref.alloca() : memref<5xf32>
+ %cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
+ %cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
+ // CHECK-NEXT: ^ size of 0-th source/target dim does not match
+ // CHECK-NEXT: Location: loc({{.*}})
+ call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
+
+ return
+}
|
7bb852c to
2a3b442
Compare
Implement runtime op verification for
memref.copy. Only ranked memrefs are verified at the moment.