Skip to content

Commit 8196459

Browse files
HanumanthHanumanth Hanumantharayappa
andauthored
[mlir][tensor] Fix runtime verification for tensor.extract_slice for empty tensor slices (llvm#166569)
I hit another runtime verification issue (similar to llvm#164878) while working with TFLite models. The verifier is incorrectly rejecting `tensor.extract_slice` operations when extracting an empty slice (size=0) that starts exactly at the tensor boundary. The current runtime verification unconditionally enforces `offset < dim_size`. This makes sense for non-empty slices, but it's too strict for empty slices, causing false positives that lead to spurious runtime assertions. **Simple example that demonstrates the issue:** ```mlir func.func @extract_empty_slice(%tensor: tensor<?xf32>, %offset: index, %size: index) { // When called with: tensor size=10, offset=10, size=0 // Runtime verification fails: "offset 0 is out-of-bounds" %slice = tensor.extract_slice %tensor[%offset] [%size] [1] : tensor<?xf32> to tensor<?xf32> return } ``` For the above example, the check evaluates `10 < 10` which is false, so verification fails. However, I believe this operation should be valid - we're extracting zero elements, so there's no actual out-of-bounds access. **Real-world repro from the TensorFlow Lite models:** This issue manifests while lowering TFLite models and a lot of our system tests are failing due to this. Here's a simplified version showing the problematic pattern: In this code, `%extracted_slice_0` becomes an empty tensor when SSA value `%15` reaches 10 (on the final loop iteration), making `%16 = 0`. The operation extracts zero elements along dimension 0, which is semantically valid but fails runtime verification. ```mlir func.func @simplified_repro_from_tensorflowlite_model(%arg0: tensor<10x4x1xf32>) -> tensor<10x4x1xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c10 = arith.constant 10 : index %c-1 = arith.constant -1 : index %0 = "tosa.const"() <{values = dense<0> : tensor<i32>}> : () -> tensor<i32> %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32> %2 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32> %3 = "tosa.const"() <{values = dense<-1> : tensor<2xi32>}> : () -> tensor<2xi32> %4 = "tosa.const"() <{values = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32> %5 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x4x1xf32>}> : () -> tensor<1x4x1xf32> %c4_1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> %6:2 = scf.while (%arg1 = %0, %arg2 = %arg0) : (tensor<i32>, tensor<10x4x1xf32>) -> (tensor<i32>, tensor<10x4x1xf32>) { %7 = tosa.greater %2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1> %extracted = tensor.extract %7[] : tensor<i1> scf.condition(%extracted) %arg1, %arg2 : tensor<i32>, tensor<10x4x1xf32> } do { ^bb0(%arg1: tensor<i32>, %arg2: tensor<10x4x1xf32>): %7 = tosa.add %arg1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32> // First slice %8 = tosa.reshape %arg1, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32> %9 = tosa.concat %8, %3 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> %extracted_0 = tensor.extract %9[%c0] : tensor<3xi32> %10 = index.casts %extracted_0 : i32 to index %11 = arith.cmpi eq, %10, %c-1 : index %12 = arith.select %11, %c10, %10 : index %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [%12, 4, 1] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x4x1xf32> // Second slice - this is where the failure occurs %13 = tosa.reshape %7, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32> %14 = tosa.concat %13, %4 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> %extracted_1 = tensor.extract %14[%c0] : tensor<3xi32> %15 = index.castu %extracted_1 : i32 to index %16 = arith.subi %c10, %15 : index // size = 10 - offset %extracted_2 = tensor.extract %14[%c1] : tensor<3xi32> %17 = index.castu %extracted_2 : i32 to index %extracted_3 = tensor.extract %14[%c2] : tensor<3xi32> %18 = index.castu %extracted_3 : i32 to index // On the last loop iteration: %15=10, %16=0 // %extracted_slice_0 becomes an empty tensor // Runtime verification fails: "offset 0 is out-of-bounds" %extracted_slice_0 = tensor.extract_slice %arg2[%15, %17, %18] [%16, 4, 1] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x4x1xf32> %19 = tosa.concat %extracted_slice, %5, %extracted_slice_0 {axis = 0 : i32} : (tensor<?x4x1xf32>, tensor<1x4x1xf32>, tensor<?x4x1xf32>) -> tensor<10x4x1xf32> scf.yield %7, %19 : tensor<i32>, tensor<10x4x1xf32> } return %6#1 : tensor<10x4x1xf32> } ``` **The fix:** Make the offset check conditional on slice size: - Empty slice (size == 0): allow `0 <= offset <= dim_size` - Non-empty slice (size > 0): require `0 <= offset < dim_size` **Question for reviewers:** Should we also relax the static verifier to allow this edge case? Currently, the static verifier rejects the following IR: ```mlir %tensor = arith.constant dense<1.0> : tensor<10xf32> %slice = tensor.extract_slice %tensor[10] [0] [1] : tensor<10xf32> to tensor<0xf32> ``` Since we're allowing it at runtime for dynamic shapes, it seems inconsistent to reject it statically. However, I wanted to get feedback before making that change - this PR focuses only on the runtime verification fix for dynamic shapes. P.S. We have a similar issue with `memref.subview`. I will send a separate patch for the issue. Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
1 parent a664f58 commit 8196459

File tree

2 files changed

+61
-33
lines changed

2 files changed

+61
-33
lines changed

mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,15 @@ struct ExtractSliceOpInterface
155155
RankedTensorType sourceType = extractSliceOp.getSource().getType();
156156

157157
// For each dimension, assert that:
158-
// 0 <= offset < dim_size
159-
// 0 <= offset + (size - 1) * stride < dim_size
158+
// For empty slices (size == 0) : 0 <= offset <= dim_size
159+
// For non-empty slices (size > 0): 0 <= offset < dim_size
160+
// 0 <= offset + (size - 1) * stride <
161+
// dim_size
160162
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
161163
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
162164

163165
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
164-
// Reset insertion point to before the operation for each dimension
166+
165167
builder.setInsertionPoint(extractSliceOp);
166168

167169
Value offset = getValueOrCreateConstantIndexOp(
@@ -170,46 +172,63 @@ struct ExtractSliceOpInterface
170172
builder, loc, extractSliceOp.getMixedSizes()[i]);
171173
Value stride = getValueOrCreateConstantIndexOp(
172174
builder, loc, extractSliceOp.getMixedStrides()[i]);
173-
174-
// Verify that offset is in-bounds.
175175
Value dimSize = builder.createOrFold<tensor::DimOp>(
176176
loc, extractSliceOp.getSource(), i);
177-
Value offsetInBounds =
178-
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
179-
cf::AssertOp::create(builder, loc, offsetInBounds,
177+
178+
// Verify that offset is in-bounds (conditional on slice size).
179+
Value sizeIsZero = arith::CmpIOp::create(
180+
builder, loc, arith::CmpIPredicate::eq, size, zero);
181+
auto offsetCheckIf = scf::IfOp::create(
182+
builder, loc, sizeIsZero,
183+
[&](OpBuilder &b, Location loc) {
184+
// For empty slices, offset can be at the boundary: 0 <= offset <=
185+
// dimSize.
186+
Value offsetGEZero = arith::CmpIOp::create(
187+
b, loc, arith::CmpIPredicate::sge, offset, zero);
188+
Value offsetLEDimSize = arith::CmpIOp::create(
189+
b, loc, arith::CmpIPredicate::sle, offset, dimSize);
190+
Value emptyOffsetValid =
191+
arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
192+
scf::YieldOp::create(b, loc, emptyOffsetValid);
193+
},
194+
[&](OpBuilder &b, Location loc) {
195+
// For non-empty slices, offset must be a valid index: 0 <= offset <
196+
// dimSize.
197+
Value offsetInBounds =
198+
generateInBoundsCheck(b, loc, offset, zero, dimSize);
199+
scf::YieldOp::create(b, loc, offsetInBounds);
200+
});
201+
202+
Value offsetCondition = offsetCheckIf.getResult(0);
203+
cf::AssertOp::create(builder, loc, offsetCondition,
180204
generateErrorMessage(op, "offset " +
181205
std::to_string(i) +
182206
" is out-of-bounds"));
183207

184-
// Only verify if size > 0
208+
// Verify that the slice endpoint is in-bounds (only for non-empty
209+
// slices).
185210
Value sizeIsNonZero = arith::CmpIOp::create(
186211
builder, loc, arith::CmpIPredicate::sgt, size, zero);
212+
auto ifOp = scf::IfOp::create(
213+
builder, loc, sizeIsNonZero,
214+
[&](OpBuilder &b, Location loc) {
215+
// Verify that slice does not run out-of-bounds.
216+
Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
217+
Value sizeMinusOneTimesStride =
218+
arith::MulIOp::create(b, loc, sizeMinusOne, stride);
219+
Value lastPos =
220+
arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
221+
Value lastPosInBounds =
222+
generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
223+
scf::YieldOp::create(b, loc, lastPosInBounds);
224+
},
225+
[&](OpBuilder &b, Location loc) {
226+
Value trueVal =
227+
arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
228+
scf::YieldOp::create(b, loc, trueVal);
229+
});
187230

188-
auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
189-
sizeIsNonZero, /*withElseRegion=*/true);
190-
191-
// Populate the "then" region (for size > 0).
192-
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
193-
194-
// Verify that slice does not run out-of-bounds.
195-
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
196-
Value sizeMinusOneTimesStride =
197-
arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
198-
Value lastPos =
199-
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
200-
Value lastPosInBounds =
201-
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
202-
scf::YieldOp::create(builder, loc, lastPosInBounds);
203-
204-
// Populate the "else" region (for size == 0).
205-
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
206-
Value trueVal =
207-
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
208-
scf::YieldOp::create(builder, loc, trueVal);
209-
210-
builder.setInsertionPointAfter(ifOp);
211231
Value finalCondition = ifOp.getResult(0);
212-
213232
cf::AssertOp::create(
214233
builder, loc, finalCondition,
215234
generateErrorMessage(

mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index,
3939
return
4040
}
4141

42+
func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) {
43+
tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
44+
return
45+
}
46+
4247

4348
func.func @main() {
4449
%0 = arith.constant 0 : index
@@ -115,5 +120,9 @@ func.func @main() {
115120
%dim_2 = arith.constant 1 : index
116121
func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
117122

123+
// CHECK-NOT: ERROR: Runtime op verification failed
124+
%offset = arith.constant 10 : index
125+
func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> ()
126+
118127
return
119128
}

0 commit comments

Comments
 (0)