Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ struct ExtractSliceOpInterface
// 0 <= offset + (size - 1) * stride < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {

for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
// Reset insertion point to before the operation for each dimension
builder.setInsertionPoint(extractSliceOp);

Value offset = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(
Expand All @@ -176,6 +180,42 @@ struct ExtractSliceOpInterface
std::to_string(i) +
" is out-of-bounds"));

// Only verify if size > 0
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);

/*
* Split the current block to create the below control flow structure:
*
* ^preCondBlock:
* ... // offset check already done above
* %size_nonzero = arith.cmpi sgt, %size, %zero
* cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
*
* ^sizeBoundsCheckBlock:
* %last_pos = ... // compute offset + (size-1) * stride
* %last_pos_ok = ... // last position bounds check
* cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
* cf.br ^afterCheckBlock
*
* ^afterCheckBlock:
* tensor.extract_slice ... // the original operation
*/
Block *preCondBlock = builder.getBlock();
Block *afterCheckBlock = preCondBlock->splitBlock(extractSliceOp);

// Create the block for conditional size bounds verification.
Block *sizeBoundsCheckBlock = builder.createBlock(
preCondBlock->getParent(), Region::iterator(afterCheckBlock));

// Terminate the pre-condition block with the conditional branch.
builder.setInsertionPointToEnd(preCondBlock);
cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
sizeBoundsCheckBlock, afterCheckBlock);

// Populate the size bounds check block with lastPos verification.
builder.setInsertionPointToStart(sizeBoundsCheckBlock);

// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
Expand All @@ -189,6 +229,7 @@ struct ExtractSliceOpInterface
generateErrorMessage(
op, "extract_slice runs out-of-bounds along dimension " +
std::to_string(i)));
cf::BranchOp::create(builder, loc, afterCheckBlock);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor<?x4xf32>, %offset:
return
}

func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
return
}


func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
Expand Down Expand Up @@ -101,6 +107,13 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()

%cst10x4x1xf32 = arith.constant dense<1.0> : tensor<10x4x1xf32>

// CHECK-NOT: ERROR: Runtime op verification failed
%dim_0 = arith.constant 0 : index
%dim_1 = arith.constant 4 : index
%dim_2 = arith.constant 1 : index
func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()

return
}