Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 55 additions & 34 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,61 +268,82 @@ struct SubViewOpInterface
MemRefType sourceType = subView.getSource().getType();

// For each dimension, assert that:
// 0 <= offset < dim_size
// 0 <= offset + (size - 1) * stride < dim_size
// For empty slices (size == 0) : 0 <= offset <= dim_size
// For non-empty slices (size > 0): 0 <= offset < dim_size
// 0 <= offset + (size - 1) * stride
// dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);

auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());

for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
// Reset insertion point to before the operation for each dimension
// Reset insertion point to before the operation for each dimension.
builder.setInsertionPoint(subView);

Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
subView.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedStrides()[i]);

// Verify that offset is in-bounds.
Value dimSize = metadataOp.getSizes()[i];
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
cf::AssertOp::create(builder, loc, offsetInBounds,

// Verify that offset is in-bounds (conditional on slice size).
Value sizeIsZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, size, zero);
auto offsetCheckIf = scf::IfOp::create(
builder, loc, sizeIsZero,
[&](OpBuilder &b, Location loc) {
// For empty slices, offset can be at the boundary: 0 <= offset <=
// dimSize.
Value offsetGEZero = arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::sge, offset, zero);
Value offsetLEDimSize = arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::sle, offset, dimSize);
Value emptyOffsetValid =
arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
scf::YieldOp::create(b, loc, emptyOffsetValid);
},
[&](OpBuilder &b, Location loc) {
// For non-empty slices, offset must be a valid index: 0 <= offset
// dimSize.
Value offsetInBounds =
generateInBoundsCheck(b, loc, offset, zero, dimSize);
scf::YieldOp::create(b, loc, offsetInBounds);
});

Value offsetCondition = offsetCheckIf.getResult(0);
cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));

// Only verify if size > 0
// Verify that the slice endpoint is in-bounds (only for non-empty
// slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
auto ifOp = scf::IfOp::create(
builder, loc, sizeIsNonZero,
[&](OpBuilder &b, Location loc) {
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
Value sizeMinusOneTimesStride =
arith::MulIOp::create(b, loc, sizeMinusOne, stride);
Value lastPos =
arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
scf::YieldOp::create(b, loc, lastPosInBounds);
},
[&](OpBuilder &b, Location loc) {
Value trueVal =
arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
scf::YieldOp::create(b, loc, trueVal);
});

auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
sizeIsNonZero, /*withElseRegion=*/true);

// Populate the "then" region (for size > 0).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
Value lastPos =
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);

scf::YieldOp::create(builder, loc, lastPosInBounds);

// Populate the "else" region (for size == 0).
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
Value trueVal =
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
scf::YieldOp::create(builder, loc, trueVal);

builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);

cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?],
return
}

func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
%dim_0: index,
%dim_1: index,
%dim_2: index,
%offset: index) {
%subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return
}


func.func @main() {
%0 = arith.constant 0 : index
Expand Down Expand Up @@ -127,5 +138,9 @@ func.func @main() {
func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()

// CHECK-NOT: ERROR: Runtime op verification failed
%offset = arith.constant 10 : index
func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset)
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> ()
return
}