Skip to content

Commit 1b455df

Browse files
[mlir][memref] Add runtime verification for memref.copy (#130437)
Implement runtime op verification for `memref.copy`. Only ranked memrefs are verified at the moment.
1 parent dbbadfd commit 1b455df

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,50 @@ struct CastOpInterface
128128
}
129129
};
130130

131+
struct CopyOpInterface
132+
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
133+
CopyOp> {
134+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
135+
Location loc) const {
136+
auto copyOp = cast<CopyOp>(op);
137+
BaseMemRefType sourceType = copyOp.getSource().getType();
138+
BaseMemRefType targetType = copyOp.getTarget().getType();
139+
auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
140+
auto rankedTargetType = dyn_cast<MemRefType>(targetType);
141+
142+
// TODO: Verification for unranked memrefs is not supported yet.
143+
if (!rankedSourceType || !rankedTargetType)
144+
return;
145+
146+
assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
147+
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
148+
// Fully static dimensions in both source and target operand are already
149+
// verified by the op verifier.
150+
if (!rankedSourceType.isDynamicDim(i) &&
151+
!rankedTargetType.isDynamicDim(i))
152+
continue;
153+
auto getDimSize = [&](Value memRef, MemRefType type,
154+
int64_t dim) -> Value {
155+
return type.isDynamicDim(dim)
156+
? builder.create<DimOp>(loc, memRef, dim).getResult()
157+
: builder
158+
.create<arith::ConstantIndexOp>(loc,
159+
type.getDimSize(dim))
160+
.getResult();
161+
};
162+
Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
163+
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
164+
Value sameDimSize = builder.create<arith::CmpIOp>(
165+
loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
166+
builder.create<cf::AssertOp>(
167+
loc, sameDimSize,
168+
RuntimeVerifiableOpInterface::generateErrorMessage(
169+
op, "size of " + std::to_string(i) +
170+
"-th source/target dim does not match"));
171+
}
172+
}
173+
};
174+
131175
/// Verifies that the indices on load/store ops are in-bounds of the memref's
132176
/// index space: 0 <= index#i < dim#i
133177
template <typename LoadStoreOp>
@@ -335,6 +379,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
335379
DialectRegistry &registry) {
336380
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
337381
CastOp::attachInterface<CastOpInterface>(*ctx);
382+
CopyOp::attachInterface<CopyOpInterface>(*ctx);
338383
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339384
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340385
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -expand-strided-metadata \
3+
// RUN: -test-cf-assert \
4+
// RUN: -convert-to-llvm | \
5+
// RUN: mlir-runner -e main -entry-point-result=void \
6+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
7+
// RUN: FileCheck %s
8+
9+
// Put memref.copy in a function, otherwise the memref.cast may fold.
10+
func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
11+
memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
12+
return
13+
}
14+
15+
func.func @main() {
16+
%alloca1 = memref.alloca() : memref<4xf32>
17+
%alloca2 = memref.alloca() : memref<5xf32>
18+
%cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
19+
%cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>
20+
21+
// CHECK: ERROR: Runtime op verification failed
22+
// CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
23+
// CHECK-NEXT: ^ size of 0-th source/target dim does not match
24+
// CHECK-NEXT: Location: loc({{.*}})
25+
call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()
26+
27+
return
28+
}

0 commit comments

Comments
 (0)