Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ using namespace mlir;
namespace mlir {
namespace memref {
namespace {
/// Generate a runtime check for lb <= value < ub.
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
Value lb, Value ub) {
Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value, lb);
Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, value, ub);
Value inBounds =
builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
return inBounds;
}

struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
Expand Down Expand Up @@ -172,6 +184,21 @@ struct CopyOpInterface
}
};

struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto dimOp = cast<DimOp>(op);
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
builder.create<cf::AssertOp>(
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "index is out of bounds"));
}
};

/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
Expand All @@ -192,19 +219,12 @@ struct LoadStoreOpInterface
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
auto index = indices[i];

auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);

auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, zero);
auto ltHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, index, dimOp);
auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);

Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
Value inBounds =
generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
assertCond =
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
: andOp;
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
: inBounds;
}
builder.create<cf::AssertOp>(
loc, assertCond,
Expand Down Expand Up @@ -380,6 +400,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CopyOp::attachInterface<CopyOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Transforms/GenerateRuntimeVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
} // namespace

void GenerateRuntimeVerificationPass::runOnOperation() {
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
// performs verification, so gather all runtime-verifiable ops first.
SmallVector<RuntimeVerifiableOpInterface> ops;
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
OpBuilder builder(getOperation()->getContext());
ops.push_back(verifiableOp);
});

OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
});
};
}

std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: -expand-strided-metadata \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s

func.func @main() {
%c4 = arith.constant 4 : index
%alloca = memref.alloca() : memref<1xf32>

// CHECK: ERROR: Runtime op verification failed
// CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
// CHECK-NEXT: ^ index is out of bounds
// CHECK-NEXT: Location: loc({{.*}})
%dim = memref.dim %alloca, %c4 : memref<1xf32>

return
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -lower-affine \
// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
// RUN: -convert-func-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Loading