Skip to content

Commit a6788b5

Browse files
Hanumanth04 and Hanumanth Hanumantharayappa authored
[mlir][tensor] Fix runtime verification for tensor.extract_slice when size dimension value is 0 (#164878)
Previously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid `tensor.extract_slice` operations where one of the dimensions had a size of 0. The `tensor.extract_slice` runtime verification logic was unconditionally generating checks for the position of the last element (`offset + (size - 1) * stride`). When `size` is 0, this expression reduces to `offset - stride`, which is negative for a slice that starts at offset 0 with a positive stride, so the assertion fails at runtime even though the operation is semantically valid.

This patch fixes the issue by making the `lastPos` check conditional. The offset is always verified, but the endpoint check is only performed when `size > 0`, which avoids generating spurious assertion failures.

This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to `tensor.extract_slice`. The following is a simplified IR snippet from the model. After running the runtime verification pass, an assertion that always fails is generated because the SSA value `%3` becomes 0.
```mlir
func.func @simple_repro_from_liteRT_model(%arg0: tensor<10x4x1xf32>) -> tensor<?x?x?xf32> {
  %cst = arith.constant dense<0> : tensor<1xi32>
  %cst_0 = arith.constant dense<-1> : tensor<2xi32>
  %c-1 = arith.constant -1 : index
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.empty() : tensor<3xi32>
  %inserted_slice = tensor.insert_slice %cst into %0[0] [1] [1] : tensor<1xi32> into tensor<3xi32>
  %inserted_slice_1 = tensor.insert_slice %cst_0 into %inserted_slice[1] [2] [1] : tensor<2xi32> into tensor<3xi32>
  %extracted = tensor.extract %inserted_slice_1[%c0] : tensor<3xi32>
  %1 = index.casts %extracted : i32 to index
  %2 = arith.cmpi eq, %1, %c-1 : index
  %3 = arith.select %2, %c10, %1 : index
  %extracted_2 = tensor.extract %inserted_slice_1[%c1] : tensor<3xi32>
  %4 = index.casts %extracted_2 : i32 to index
  %5 = arith.cmpi eq, %4, %c-1 : index
  %6 = arith.select %5, %c4, %4 : index
  %extracted_3 = tensor.extract %inserted_slice_1[%c2] : tensor<3xi32>
  %7 = index.casts %extracted_3 : i32 to index
  %8 = arith.cmpi eq, %7, %c-1 : index
  %9 = arith.select %8, %c1, %7 : index
  %extracted_slice = tensor.extract_slice %arg0[0, 0, 0] [%3, %6, %9] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return %extracted_slice : tensor<?x?x?xf32>
}
```

The issue can be reproduced more simply with the following test case, where `dim_0` is `0`. When the runtime verification pass is applied to this code with `dim_0 = 0`, it generates an assertion that will always fail at runtime.
```mlir
func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
  %slice = tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return
}

func.func @test_zero_size_extraction() {
  %input = arith.constant dense<1.0> : tensor<10x4x1xf32>
  // Define slice dimensions: 0x4x1 (zero-size in first dimension)
  %dim_0 = arith.constant 0 : index
  %dim_1 = arith.constant 4 : index
  %dim_2 = arith.constant 1 : index
  func.call @extract_slice_zero_size_dim(%input, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
  return
}
```

P.S. We probably have a similar issue with `memref.subview`. I will check this and send a separate PR for the issue.

---------

Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
1 parent f8a0599 commit a6788b5

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/Utils/Utils.h"
1313
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1414
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
1516
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1617
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
1718

@@ -158,7 +159,11 @@ struct ExtractSliceOpInterface
158159
// 0 <= offset + (size - 1) * stride < dim_size
159160
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
160161
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
161-
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
162+
163+
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
164+
// Reset insertion point to before the operation for each dimension
165+
builder.setInsertionPoint(extractSliceOp);
166+
162167
Value offset = getValueOrCreateConstantIndexOp(
163168
builder, loc, extractSliceOp.getMixedOffsets()[i]);
164169
Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +181,16 @@ struct ExtractSliceOpInterface
176181
std::to_string(i) +
177182
" is out-of-bounds"));
178183

184+
// Only verify if size > 0
185+
Value sizeIsNonZero = arith::CmpIOp::create(
186+
builder, loc, arith::CmpIPredicate::sgt, size, zero);
187+
188+
auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
189+
sizeIsNonZero, /*withElseRegion=*/true);
190+
191+
// Populate the "then" region (for size > 0).
192+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
193+
179194
// Verify that slice does not run out-of-bounds.
180195
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
181196
Value sizeMinusOneTimesStride =
@@ -184,8 +199,19 @@ struct ExtractSliceOpInterface
184199
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
185200
Value lastPosInBounds =
186201
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
202+
scf::YieldOp::create(builder, loc, lastPosInBounds);
203+
204+
// Populate the "else" region (for size == 0).
205+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
206+
Value trueVal =
207+
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
208+
scf::YieldOp::create(builder, loc, trueVal);
209+
210+
builder.setInsertionPointAfter(ifOp);
211+
Value finalCondition = ifOp.getResult(0);
212+
187213
cf::AssertOp::create(
188-
builder, loc, lastPosInBounds,
214+
builder, loc, finalCondition,
189215
generateErrorMessage(
190216
op, "extract_slice runs out-of-bounds along dimension " +
191217
std::to_string(i)));

mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor<?x4xf32>, %offset:
3434
return
3535
}
3636

37+
func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
38+
tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
39+
return
40+
}
41+
42+
3743
func.func @main() {
3844
%0 = arith.constant 0 : index
3945
%1 = arith.constant 1 : index
@@ -101,6 +107,13 @@ func.func @main() {
101107
// CHECK-NOT: ERROR: Runtime op verification failed
102108
func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()
103109

110+
%cst10x4x1xf32 = arith.constant dense<1.0> : tensor<10x4x1xf32>
111+
112+
// CHECK-NOT: ERROR: Runtime op verification failed
113+
%dim_0 = arith.constant 0 : index
114+
%dim_1 = arith.constant 4 : index
115+
%dim_2 = arith.constant 1 : index
116+
func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
104117

105118
return
106119
}

0 commit comments

Comments
 (0)