Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,50 @@ struct CastOpInterface
}
};

struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
auto rankedTargetType = dyn_cast<MemRefType>(targetType);

// TODO: Verification for unranked memrefs is not supported yet.
if (!rankedSourceType || !rankedTargetType)
return;

assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
// Fully static dimensions in both source and target operand are already
// verified by the op verifier.
if (!rankedSourceType.isDynamicDim(i) &&
!rankedTargetType.isDynamicDim(i))
continue;
auto getDimSize = [&](Value memRef, MemRefType type,
int64_t dim) -> Value {
return type.isDynamicDim(dim)
? builder.create<DimOp>(loc, memRef, dim).getResult()
: builder
.create<arith::ConstantIndexOp>(loc,
type.getDimSize(dim))
.getResult();
};
Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
Value sameDimSize = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
builder.create<cf::AssertOp>(
loc, sameDimSize,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size of " + std::to_string(i) +
"-th source/target dim does not match"));
}
}
};

/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
Expand Down Expand Up @@ -335,6 +379,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CopyOp::attachInterface<CopyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s

// Put memref.copy in a function, otherwise the memref.cast may fold.
func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
return
}

func.func @main() {
%alloca1 = memref.alloca() : memref<4xf32>
%alloca2 = memref.alloca() : memref<5xf32>
%cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32>
%cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32>

// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> ()
// CHECK-NEXT: ^ size of 0-th source/target dim does not match
// CHECK-NEXT: Location: loc({{.*}})
call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> ()

return
}