You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[mlir][tensor] Fix runtime verification for tensor.extract_slice for empty tensor slices (llvm#166569)
I hit another runtime verification issue (similar to
llvm#164878) while working with
TFLite models. The verifier is incorrectly rejecting
`tensor.extract_slice` operations when extracting an empty slice
(size=0) that starts exactly at the tensor boundary.
The current runtime verification unconditionally enforces `offset <
dim_size`. This makes sense for non-empty slices, but it's too strict
for empty slices, causing false positives that lead to spurious runtime
assertions.
**Simple example that demonstrates the issue:**
```mlir
func.func @extract_empty_slice(%tensor: tensor<?xf32>, %offset: index, %size: index) {
// When called with: tensor size=10, offset=10, size=0
// Runtime verification fails: "offset 0 is out-of-bounds"
%slice = tensor.extract_slice %tensor[%offset] [%size] [1]
: tensor<?xf32> to tensor<?xf32>
return
}
```
For the above example, the check evaluates `10 < 10` which is false, so
verification fails. However, I believe this operation should be valid -
we're extracting zero elements, so there's no actual out-of-bounds
access.
**Real-world repro from the TensorFlow Lite models:**
This issue manifests while lowering TFLite models and a lot of our
system tests are failing due to this. Here's a simplified version
showing the problematic pattern:
In this code, `%extracted_slice_0` becomes an empty tensor when SSA
value `%15` reaches 10 (on the final loop iteration), making `%16 = 0`.
The operation extracts zero elements along dimension 0, which is
semantically valid but fails runtime verification.
```mlir
func.func @simplified_repro_from_tensorflowlite_model(%arg0: tensor<10x4x1xf32>) -> tensor<10x4x1xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c10 = arith.constant 10 : index
%c-1 = arith.constant -1 : index
%0 = "tosa.const"() <{values = dense<0> : tensor<i32>}> : () -> tensor<i32>
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
%2 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
%3 = "tosa.const"() <{values = dense<-1> : tensor<2xi32>}> : () -> tensor<2xi32>
%4 = "tosa.const"() <{values = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
%5 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x4x1xf32>}> : () -> tensor<1x4x1xf32>
%c4_1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%6:2 = scf.while (%arg1 = %0, %arg2 = %arg0)
: (tensor<i32>, tensor<10x4x1xf32>) -> (tensor<i32>, tensor<10x4x1xf32>) {
%7 = tosa.greater %2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
%extracted = tensor.extract %7[] : tensor<i1>
scf.condition(%extracted) %arg1, %arg2 : tensor<i32>, tensor<10x4x1xf32>
} do {
^bb0(%arg1: tensor<i32>, %arg2: tensor<10x4x1xf32>):
%7 = tosa.add %arg1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
// First slice
%8 = tosa.reshape %arg1, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
%9 = tosa.concat %8, %3 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
%extracted_0 = tensor.extract %9[%c0] : tensor<3xi32>
%10 = index.casts %extracted_0 : i32 to index
%11 = arith.cmpi eq, %10, %c-1 : index
%12 = arith.select %11, %c10, %10 : index
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [%12, 4, 1] [1, 1, 1]
: tensor<10x4x1xf32> to tensor<?x4x1xf32>
// Second slice - this is where the failure occurs
%13 = tosa.reshape %7, %c4_1 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
%14 = tosa.concat %13, %4 {axis = 0 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
%extracted_1 = tensor.extract %14[%c0] : tensor<3xi32>
%15 = index.castu %extracted_1 : i32 to index
%16 = arith.subi %c10, %15 : index // size = 10 - offset
%extracted_2 = tensor.extract %14[%c1] : tensor<3xi32>
%17 = index.castu %extracted_2 : i32 to index
%extracted_3 = tensor.extract %14[%c2] : tensor<3xi32>
%18 = index.castu %extracted_3 : i32 to index
// On the last loop iteration: %15=10, %16=0
// %extracted_slice_0 becomes an empty tensor
// Runtime verification fails: "offset 0 is out-of-bounds"
%extracted_slice_0 = tensor.extract_slice %arg2[%15, %17, %18] [%16, 4, 1] [1, 1, 1]
: tensor<10x4x1xf32> to tensor<?x4x1xf32>
%19 = tosa.concat %extracted_slice, %5, %extracted_slice_0 {axis = 0 : i32}
: (tensor<?x4x1xf32>, tensor<1x4x1xf32>, tensor<?x4x1xf32>) -> tensor<10x4x1xf32>
scf.yield %7, %19 : tensor<i32>, tensor<10x4x1xf32>
}
return %6#1 : tensor<10x4x1xf32>
}
```
**The fix:**
Make the offset check conditional on slice size:
- Empty slice (size == 0): allow `0 <= offset <= dim_size`
- Non-empty slice (size > 0): require `0 <= offset < dim_size`
**Question for reviewers:**
Should we also relax the static verifier to allow this edge case?
Currently, the static verifier rejects the following IR:
```mlir
%tensor = arith.constant dense<1.0> : tensor<10xf32>
%slice = tensor.extract_slice %tensor[10] [0] [1] : tensor<10xf32> to tensor<0xf32>
```
Since we're allowing it at runtime for dynamic shapes, it seems
inconsistent to reject it statically. However, I wanted to get feedback
before making that change - this PR focuses only on the runtime
verification fix for dynamic shapes.
P.S. We have a similar issue with `memref.subview`. I will send a
separate patch for the issue.
Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
0 commit comments