Skip to content

Commit 7bb852c

Browse files
[mlir][memref] Add runtime verification for memref.copy
1 parent c37848a commit 7bb852c

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,53 @@ struct CastOpInterface
182182
}
183183
};
184184

185+
struct CopyOpInterface
186+
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
187+
CopyOp> {
188+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
189+
Location loc) const {
190+
auto copyOp = cast<CopyOp>(op);
191+
BaseMemRefType sourceType = copyOp.getSource().getType();
192+
BaseMemRefType targetType = copyOp.getTarget().getType();
193+
auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
194+
auto rankedTargetType = dyn_cast<MemRefType>(targetType);
195+
196+
// TODO: Verification for unranked memrefs is not supported yet.
197+
if (!rankedSourceType || !rankedTargetType)
198+
return;
199+
200+
assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
201+
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
202+
// Fully static dimensions in both source and target operand are already
203+
// verified by the op verifier.
204+
if (!rankedSourceType.isDynamicDim(i) &&
205+
!rankedTargetType.isDynamicDim(i))
206+
continue;
207+
Value sourceDim;
208+
if (rankedSourceType.isDynamicDim(i)) {
209+
sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i);
210+
} else {
211+
sourceDim = builder.create<arith::ConstantIndexOp>(
212+
loc, rankedSourceType.getDimSize(i));
213+
}
214+
Value targetDim;
215+
if (rankedTargetType.isDynamicDim(i)) {
216+
targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i);
217+
} else {
218+
targetDim = builder.create<arith::ConstantIndexOp>(
219+
loc, rankedTargetType.getDimSize(i));
220+
}
221+
Value sameDimSize = builder.create<arith::CmpIOp>(
222+
loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
223+
builder.create<cf::AssertOp>(
224+
loc, sameDimSize,
225+
RuntimeVerifiableOpInterface::generateErrorMessage(
226+
op, "size of " + std::to_string(i) +
227+
"-th source/target dim does not match"));
228+
}
229+
}
230+
};
231+
185232
struct DimOpInterface
186233
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
187234
DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
383430
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
384431
AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
385432
CastOp::attachInterface<CastOpInterface>(*ctx);
433+
CopyOp::attachInterface<CopyOpInterface>(*ctx);
386434
DimOp::attachInterface<DimOpInterface>(*ctx);
387435
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
388436
GenericAtomicRMWOp::attachInterface<
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -expand-strided-metadata \
3+
// RUN: -test-cf-assert \
4+
// RUN: -convert-to-llvm | \
5+
// RUN: mlir-runner -e main -entry-point-result=void \
6+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
7+
// RUN: FileCheck %s
8+
9+
// Put memref.copy in a function, otherwise the memref.cast may fold.
10+
func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
11+
memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
12+
return
13+
}
14+
15+
func.func @main() {
16+
%alloca1 = memref.alloca() : memref<4xf32>
17+
%alloca2 = memref.alloca() : memref<5xf32>
18+
%cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
19+
%cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
20+
21+
// CHECK: ERROR: Runtime op verification failed
22+
// CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
23+
// CHECK-NEXT: ^ size of 0-th source/target dim does not match
24+
// CHECK-NEXT: Location: loc({{.*}})
25+
call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
26+
27+
return
28+
}

0 commit comments

Comments
 (0)