diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 134e8b5efcfdf..4537977226087 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -327,47 +327,52 @@ struct ReinterpretCastOpInterface } }; -/// Verifies that the linear bounds of a subview op are within the linear bounds -/// of the base memref: low >= baseLow && high <= baseHigh -/// TODO: This is not yet a full runtime verification of subview. For example, -/// consider: -/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> -/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] -/// : memref to memref -/// The subview is in-bounds of the entire base memref but the first dimension -/// is out-of-bounds. Future work would verify the bounds on a per-dimension -/// basis. struct SubViewOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto subView = cast(op); - auto baseMemref = cast>(subView.getSource()); - auto resultMemref = cast>(subView.getResult()); + MemRefType sourceType = subView.getSource().getType(); - builder.setInsertionPointAfter(op); - - // Compute the linear bounds of the base memref - auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); - - // Compute the linear bounds of the resulting memref - auto [low, high] = computeLinearBounds(builder, loc, resultMemref); - - // Check low >= baseLow - auto geLow = builder.createOrFold( - loc, arith::CmpIPredicate::sge, low, baseLow); - - // Check high <= baseHigh - auto leHigh = builder.createOrFold( - loc, arith::CmpIPredicate::sle, high, baseHigh); - - auto assertCond = builder.createOrFold(loc, geLow, leHigh); - - builder.create( - loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "subview is out-of-bounds of the base memref")); + // For each dimension, assert that: + // 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride < dim_size + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + auto metadataOp = + builder.create(loc, subView.getSource()); + for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + Value offset = getValueOrCreateConstantIndexOp( + builder, loc, subView.getMixedOffsets()[i]); + Value size = getValueOrCreateConstantIndexOp(builder, loc, + subView.getMixedSizes()[i]); + Value stride = getValueOrCreateConstantIndexOp( + builder, loc, subView.getMixedStrides()[i]); + + // Verify that offset is in-bounds. + Value dimSize = metadataOp.getSizes()[i]; + Value offsetInBounds = + generateInBoundsCheck(builder, loc, offset, zero, dimSize); + builder.create( + loc, offsetInBounds, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "offset " + std::to_string(i) + " is out-of-bounds")); + + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = builder.create(loc, size, one); + Value sizeMinusOneTimesStride = + builder.create(loc, sizeMinusOne, stride); + Value lastPos = + builder.create(loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + builder.create( + loc, lastPosInBounds, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "Subview runs out-of-bounds along dimension" + + std::to_string(i))); + } } }; diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 3cac37a082c30..ec7e4085f2fa5 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -39,38 +39,50 @@ func.func @main() { %alloca_4 = memref.alloca() : memref<4x4xf32> %alloca_4_dyn = memref.cast %alloca_4 : memref<4x4xf32> to memref - // Offset is out-of-bounds + // Offset is out-of-bounds and slice runs out-of-bounds // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.subview" - // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref, index, index, index) -> memref> + // CHECK-NEXT: ^ offset 0 is out-of-bounds + // CHECK-NEXT: Location: loc({{.*}}) + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref, index, index, index) -> memref> + // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0 // CHECK-NEXT: Location: loc({{.*}}) func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %5, %5, %1) : (memref, index, index, index) -> () - // Offset is out-of-bounds + // Offset is out-of-bounds and slice runs out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>> + // CHECK-NEXT: ^ offset 0 is out-of-bounds + // CHECK-NEXT: Location: loc({{.*}}) // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.subview" - // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>> + // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0 // CHECK-NEXT: Location: loc({{.*}}) func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> () - // Offset is out-of-bounds + // Offset is out-of-bounds and slice runs out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>> + // CHECK-NEXT: ^ offset 0 is out-of-bounds + // CHECK-NEXT: Location: loc({{.*}}) // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.subview" - // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>> + // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0 // CHECK-NEXT: Location: loc({{.*}}) func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> () - // Size is out-of-bounds + // Slice runs out-of-bounds due to size // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.subview" - // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref, index, index, index) -> memref> + // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0 // CHECK-NEXT: Location: loc({{.*}}) func.call @subview_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref, index, index, index) -> () - // Stride is out-of-bounds + // Slice runs out-of-bounds due to stride // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: "memref.subview" - // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref, index, index, index) -> memref> + // CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0 // CHECK-NEXT: Location: loc({{.*}}) func.call @subview_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref, index, index, index) -> ()