diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 53a618d787333..3fd561de3b5e6 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -23,6 +23,18 @@ using namespace mlir; namespace mlir { namespace memref { namespace { +/// Generate a runtime check for lb <= value < ub. +Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, + Value lb, Value ub) { + Value inBounds1 = builder.createOrFold( + loc, arith::CmpIPredicate::sge, value, lb); + Value inBounds2 = builder.createOrFold( + loc, arith::CmpIPredicate::slt, value, ub); + Value inBounds = + builder.createOrFold(loc, inBounds1, inBounds2); + return inBounds; +} + struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { @@ -172,6 +184,21 @@ struct CopyOpInterface } }; +struct DimOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto dimOp = cast(op); + Value rank = builder.create(loc, dimOp.getSource()); + Value zero = builder.create(loc, 0); + builder.create( + loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "index is out of bounds")); + } +}; + /// Verifies that the indices on load/store ops are in-bounds of the memref's /// index space: 0 <= index#i < dim#i template @@ -192,19 +219,12 @@ struct LoadStoreOpInterface auto zero = builder.create(loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { - auto index = indices[i]; - - auto dimOp = builder.createOrFold(loc, memref, i); - - auto geLow = builder.createOrFold( - loc, arith::CmpIPredicate::sge, index, zero); - auto ltHigh = builder.createOrFold( - loc, arith::CmpIPredicate::slt, index, dimOp); - auto andOp = builder.createOrFold(loc, geLow, ltHigh); - + Value dimOp = builder.createOrFold(loc, memref, i); + Value inBounds = + generateInBoundsCheck(builder, loc, indices[i], zero, dimOp); assertCond = - i > 0 ? builder.createOrFold(loc, assertCond, andOp) - : andOp; + i > 0 ? builder.createOrFold(loc, assertCond, inBounds) + : inBounds; } builder.create( loc, assertCond, @@ -380,6 +400,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { CastOp::attachInterface(*ctx); CopyOp::attachInterface(*ctx); + DimOp::attachInterface(*ctx); ExpandShapeOp::attachInterface(*ctx); LoadOp::attachInterface>(*ctx); ReinterpretCastOp::attachInterface(*ctx); diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp index 62db9ce1316ae..a40bc2b3272fc 100644 --- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp +++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp @@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass } // namespace void GenerateRuntimeVerificationPass::runOnOperation() { + // The implementation of the RuntimeVerifiableOpInterface may create ops that + // can be verified. We don't want to generate verification for IR that + // performs verification, so gather all runtime-verifiable ops first. + SmallVector ops; getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) { - OpBuilder builder(getOperation()->getContext()); + ops.push_back(verifiableOp); + }); + + OpBuilder builder(getOperation()->getContext()); + for (RuntimeVerifiableOpInterface verifiableOp : ops) { builder.setInsertionPoint(verifiableOp); verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc()); - }); + }; } std::unique_ptr mlir::createGenerateRuntimeVerificationPass() { diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir index b101a875154ff..8b6308e9c1939 100644 --- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir @@ -1,8 +1,7 @@ -// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \ +// RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-func-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir new file mode 100644 index 0000000000000..2e3f271743c93 --- /dev/null +++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @main() { + %c4 = arith.constant 4 : index + %alloca = memref.alloca() : memref<1xf32> + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index + // CHECK-NEXT: ^ index is out of bounds + // CHECK-NEXT: Location: loc({{.*}}) + %dim = memref.dim %alloca, %c4 : memref<1xf32> + + return +} diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir index d6c5d6da0041e..b87e5bdf0970c 100644 --- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir @@ -1,10 +1,8 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -expand-strided-metadata \ -// RUN: -finalize-memref-to-llvm \ // RUN: -test-cf-assert \ -// RUN: -convert-func-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir index 9fea48bdfc07d..601a53f4b5cd9 100644 --- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir @@ -1,10 +1,8 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -lower-affine \ -// RUN: -finalize-memref-to-llvm \ // RUN: -test-cf-assert \ -// RUN: -convert-func-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 66474e9c4ae37..3cac37a082c30 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -1,11 +1,8 @@ // RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ -// RUN: -finalize-memref-to-llvm \ -// RUN: -test-cf-assert \ -// RUN: -convert-func-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s