diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 14152c5a1af0c..e5cc41e2c43ba 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -268,61 +268,82 @@ struct SubViewOpInterface MemRefType sourceType = subView.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); + auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); + for (int64_t i : llvm::seq(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + // Reset insertion point to before the operation for each dimension. builder.setInsertionPoint(subView); + Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, subView.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = metadataOp.getSizes()[i]; - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). + Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage(op, diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 84875675ac3d0..09cfee16ccd00 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], return } +func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, + %dim_0: index, + %dim_1: index, + %dim_2: index, + %offset: index) { + %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : + memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to + memref> + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -127,5 +138,9 @@ func.func @main() { func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2) : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset) + : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> () return }