diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index cd92026562da9..10d992fa9dc49 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -255,78 +255,6 @@ struct LoadStoreOpInterface } }; -/// Compute the linear index for the provided strided layout and indices. -Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, - ArrayRef strides, - ArrayRef indices) { - auto [expr, values] = computeLinearIndex(offset, strides, indices); - auto index = - affine::makeComposedFoldedAffineApply(builder, loc, expr, values); - return getValueOrCreateConstantIndexOp(builder, loc, index); -} - -/// Returns two Values representing the bounds of the provided strided layout -/// metadata. The bounds are returned as a half open interval -- [low, high). -std::pair computeLinearBounds(OpBuilder &builder, Location loc, - OpFoldResult offset, - ArrayRef strides, - ArrayRef sizes) { - auto zeros = SmallVector(sizes.size(), 0); - auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros); - auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); - auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes); - return {lowerBound, upperBound}; -} - -/// Returns two Values representing the bounds of the memref. The bounds are -/// returned as a half open interval -- [low, high). -std::pair computeLinearBounds(OpBuilder &builder, Location loc, - TypedValue memref) { - auto runtimeMetadata = builder.create(loc, memref); - auto offset = runtimeMetadata.getConstifiedMixedOffset(); - auto strides = runtimeMetadata.getConstifiedMixedStrides(); - auto sizes = runtimeMetadata.getConstifiedMixedSizes(); - return computeLinearBounds(builder, loc, offset, strides, sizes); -} - -/// Verifies that the linear bounds of a reinterpret_cast op are within the -/// linear bounds of the base memref: low >= baseLow && high <= baseHigh -struct ReinterpretCastOpInterface - : public RuntimeVerifiableOpInterface::ExternalModel< - ReinterpretCastOpInterface, ReinterpretCastOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { - auto reinterpretCast = cast(op); - auto baseMemref = reinterpretCast.getSource(); - auto resultMemref = - cast>(reinterpretCast.getResult()); - - builder.setInsertionPointAfter(op); - - // Compute the linear bounds of the base memref - auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); - - // Compute the linear bounds of the resulting memref - auto [low, high] = computeLinearBounds(builder, loc, resultMemref); - - // Check low >= baseLow - auto geLow = builder.createOrFold( - loc, arith::CmpIPredicate::sge, low, baseLow); - - // Check high <= baseHigh - auto leHigh = builder.createOrFold( - loc, arith::CmpIPredicate::sle, high, baseHigh); - - auto assertCond = builder.createOrFold(loc, geLow, leHigh); - - builder.create( - loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, - "result of reinterpret_cast is out-of-bounds of the base memref")); - } -}; - struct SubViewOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { @@ -431,9 +359,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( DimOp::attachInterface(*ctx); ExpandShapeOp::attachInterface(*ctx); LoadOp::attachInterface>(*ctx); - ReinterpretCastOp::attachInterface(*ctx); StoreOp::attachInterface>(*ctx); SubViewOp::attachInterface(*ctx); + // Note: There is nothing to verify for ReinterpretCastOp. // Load additional dialects of which ops may get created. ctx->loadDialect&1 | \ -// RUN: FileCheck %s - -func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) { - memref.reinterpret_cast %memref to - offset: [%offset], - sizes: [1], - strides: [1] - : memref<1xf32> to memref<1xf32, strided<[1], offset: ?>> - return -} - -func.func @reinterpret_cast_fully_dynamic(%memref: memref, %offset: index, %size: index, %stride: index) { - memref.reinterpret_cast %memref to - offset: [%offset], - sizes: [%size], - strides: [%stride] - : memref to memref> - return -} - -func.func @main() { - %0 = arith.constant 0 : index - %1 = arith.constant 1 : index - %n1 = arith.constant -1 : index - %4 = arith.constant 4 : index - %5 = arith.constant 5 : index - - %alloca_1 = memref.alloca() : memref<1xf32> - %alloca_4 = memref.alloca() : memref<4xf32> - %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref - - // Offset is out-of-bounds - // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) - // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref - // CHECK-NEXT: Location: loc({{.*}}) - func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> () - - // Offset is out-of-bounds - // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) - // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref - // CHECK-NEXT: Location: loc({{.*}}) - func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> () - - // Size is out-of-bounds - // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) - // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref - // CHECK-NEXT: Location: loc({{.*}}) - func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref, index, index, index) -> () - - // Stride is out-of-bounds - // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) - // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref - // CHECK-NEXT: Location: loc({{.*}}) - func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref, index, index, index) -> () - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> () - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref, index, index, index) -> () - - return -}