Skip to content

Commit 84fec7c

Browse files
author
Hanumanth Hanumantharayappa
committed
[mlir][tensor] Fix runtime verification for empty tensor slices at boundary positions
1 parent 1667feb commit 84fec7c

File tree

2 files changed

+61
-33
lines changed

2 files changed

+61
-33
lines changed

mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,15 @@ struct ExtractSliceOpInterface
155155
RankedTensorType sourceType = extractSliceOp.getSource().getType();
156156

157157
// For each dimension, assert that:
158-
// 0 <= offset < dim_size
159-
// 0 <= offset + (size - 1) * stride < dim_size
158+
// For empty slices (size == 0) : 0 <= offset <= dim_size
159+
// For non-empty slices (size > 0): 0 <= offset < dim_size
160+
// 0 <= offset + (size - 1) * stride <
161+
// dim_size
160162
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
161163
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
162164

163165
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
164-
// Reset insertion point to before the operation for each dimension
166+
165167
builder.setInsertionPoint(extractSliceOp);
166168

167169
Value offset = getValueOrCreateConstantIndexOp(
@@ -170,46 +172,63 @@ struct ExtractSliceOpInterface
170172
builder, loc, extractSliceOp.getMixedSizes()[i]);
171173
Value stride = getValueOrCreateConstantIndexOp(
172174
builder, loc, extractSliceOp.getMixedStrides()[i]);
173-
174-
// Verify that offset is in-bounds.
175175
Value dimSize = builder.createOrFold<tensor::DimOp>(
176176
loc, extractSliceOp.getSource(), i);
177-
Value offsetInBounds =
178-
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
179-
cf::AssertOp::create(builder, loc, offsetInBounds,
177+
178+
// Verify that offset is in-bounds (conditional on slice size).
179+
Value sizeIsZero = arith::CmpIOp::create(
180+
builder, loc, arith::CmpIPredicate::eq, size, zero);
181+
auto offsetCheckIf = scf::IfOp::create(
182+
builder, loc, sizeIsZero,
183+
[&](OpBuilder &b, Location loc) {
184+
// For empty slices, offset can be at the boundary: 0 <= offset <=
185+
// dimSize.
186+
Value offsetGEZero = arith::CmpIOp::create(
187+
b, loc, arith::CmpIPredicate::sge, offset, zero);
188+
Value offsetLEDimSize = arith::CmpIOp::create(
189+
b, loc, arith::CmpIPredicate::sle, offset, dimSize);
190+
Value emptyOffsetValid =
191+
arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
192+
scf::YieldOp::create(b, loc, emptyOffsetValid);
193+
},
194+
[&](OpBuilder &b, Location loc) {
195+
// For non-empty slices, offset must be a valid index: 0 <= offset <
196+
// dimSize.
197+
Value offsetInBounds =
198+
generateInBoundsCheck(b, loc, offset, zero, dimSize);
199+
scf::YieldOp::create(b, loc, offsetInBounds);
200+
});
201+
202+
Value offsetCondition = offsetCheckIf.getResult(0);
203+
cf::AssertOp::create(builder, loc, offsetCondition,
180204
generateErrorMessage(op, "offset " +
181205
std::to_string(i) +
182206
" is out-of-bounds"));
183207

184-
// Only verify if size > 0
208+
// Verify that the slice endpoint is in-bounds (only for non-empty
209+
// slices).
185210
Value sizeIsNonZero = arith::CmpIOp::create(
186211
builder, loc, arith::CmpIPredicate::sgt, size, zero);
212+
auto ifOp = scf::IfOp::create(
213+
builder, loc, sizeIsNonZero,
214+
[&](OpBuilder &b, Location loc) {
215+
// Verify that slice does not run out-of-bounds.
216+
Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
217+
Value sizeMinusOneTimesStride =
218+
arith::MulIOp::create(b, loc, sizeMinusOne, stride);
219+
Value lastPos =
220+
arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
221+
Value lastPosInBounds =
222+
generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
223+
scf::YieldOp::create(b, loc, lastPosInBounds);
224+
},
225+
[&](OpBuilder &b, Location loc) {
226+
Value trueVal =
227+
arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
228+
scf::YieldOp::create(b, loc, trueVal);
229+
});
187230

188-
auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
189-
sizeIsNonZero, /*withElseRegion=*/true);
190-
191-
// Populate the "then" region (for size > 0).
192-
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
193-
194-
// Verify that slice does not run out-of-bounds.
195-
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
196-
Value sizeMinusOneTimesStride =
197-
arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
198-
Value lastPos =
199-
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
200-
Value lastPosInBounds =
201-
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
202-
scf::YieldOp::create(builder, loc, lastPosInBounds);
203-
204-
// Populate the "else" region (for size == 0).
205-
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
206-
Value trueVal =
207-
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
208-
scf::YieldOp::create(builder, loc, trueVal);
209-
210-
builder.setInsertionPointAfter(ifOp);
211231
Value finalCondition = ifOp.getResult(0);
212-
213232
cf::AssertOp::create(
214233
builder, loc, finalCondition,
215234
generateErrorMessage(

mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index,
3939
return
4040
}
4141

42+
func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) {
43+
tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
44+
return
45+
}
46+
4247

4348
func.func @main() {
4449
%0 = arith.constant 0 : index
@@ -115,5 +120,9 @@ func.func @main() {
115120
%dim_2 = arith.constant 1 : index
116121
func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
117122

123+
// CHECK-NOT: ERROR: Runtime op verification failed
124+
%offset = arith.constant 10 : index
125+
func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> ()
126+
118127
return
119128
}

0 commit comments

Comments
 (0)