Skip to content

Commit cbe7c49

Browse files
Hanumanth04 and Hanumanth Hanumantharayappa authored
[mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 (llvm#164897)
Previously, the runtime verification pass would insert assertions that fail at runtime for semantically valid `memref.subview` operations in which one of the dimensions has a size of 0. The `memref.subview` runtime verification logic unconditionally generated a bounds check for the position of the last element (`offset + (size - 1) * stride`). When `size` is 0, that expression reduces to `offset - stride`, which is negative for an offset of 0 and a positive stride, so the in-bounds condition evaluates to false and the program aborts even though the operation is semantically valid. This patch fixes the issue by making the `lastPos` check conditional: the offset is always verified, but the endpoint check is only performed when `size > 0`, which avoids generating spurious assertion failures. This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to `memref.subview`. The following is a simplified IR snippet from the model. After running the runtime verification pass, an assertion that always fails is generated because the SSA value `%5` evaluates to 0.
```mlir
module {
  memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
  memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
  func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %c-1 = arith.constant -1 : index
    %0 = memref.get_global @__constant_1xi32 : memref<1xi32>
    %1 = memref.get_global @__constant_2xi32 : memref<2xi32>
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
    %subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
    memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
    %subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
    memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
    %2 = memref.load %alloca[%c0] : memref<3xi32>
    %3 = index.casts %2 : i32 to index
    %4 = arith.cmpi eq, %3, %c-1 : index
    %5 = arith.select %4, %c10, %3 : index
    %6 = memref.load %alloca[%c1] : memref<3xi32>
    %7 = index.casts %6 : i32 to index
    %8 = arith.cmpi eq, %7, %c-1 : index
    %9 = arith.select %8, %c4, %7 : index
    %10 = memref.load %alloca[%c2] : memref<3xi32>
    %11 = index.casts %10 : i32 to index
    %12 = arith.cmpi eq, %11, %c-1 : index
    %13 = arith.select %12, %c1, %11 : index
    %subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  }
}
```

P.S. This is a similar issue to the one fixed for `tensor.extract_slice` in llvm#164878

---------

Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
1 parent a6788b5 commit cbe7c49

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18+
#include "mlir/Dialect/SCF/IR/SCF.h"
1819
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
1920

2021
using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
273274
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
274275
auto metadataOp =
275276
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
276-
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
277+
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
278+
// Reset insertion point to before the operation for each dimension
279+
builder.setInsertionPoint(subView);
277280
Value offset = getValueOrCreateConstantIndexOp(
278281
builder, loc, subView.getMixedOffsets()[i]);
279282
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
290293
std::to_string(i) +
291294
" is out-of-bounds"));
292295

296+
// Only verify if size > 0
297+
Value sizeIsNonZero = arith::CmpIOp::create(
298+
builder, loc, arith::CmpIPredicate::sgt, size, zero);
299+
300+
auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
301+
sizeIsNonZero, /*withElseRegion=*/true);
302+
303+
// Populate the "then" region (for size > 0).
304+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
305+
293306
// Verify that slice does not run out-of-bounds.
294307
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
295308
Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
298311
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
299312
Value lastPosInBounds =
300313
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
314+
315+
scf::YieldOp::create(builder, loc, lastPosInBounds);
316+
317+
// Populate the "else" region (for size == 0).
318+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
319+
Value trueVal =
320+
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
321+
scf::YieldOp::create(builder, loc, trueVal);
322+
323+
builder.setInsertionPointAfter(ifOp);
324+
Value finalCondition = ifOp.getResult(0);
325+
301326
cf::AssertOp::create(
302-
builder, loc, lastPosInBounds,
327+
builder, loc, finalCondition,
303328
generateErrorMessage(op,
304329
"subview runs out-of-bounds along dimension " +
305330
std::to_string(i)));

mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// RUN: -expand-strided-metadata \
33
// RUN: -lower-affine \
44
// RUN: -test-cf-assert \
5+
// RUN: -convert-scf-to-cf \
56
// RUN: -convert-to-llvm | \
67
// RUN: mlir-runner -e main -entry-point-result=void \
78
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
@@ -11,6 +12,7 @@
1112
// RUN: -expand-strided-metadata \
1213
// RUN: -lower-affine \
1314
// RUN: -test-cf-assert \
15+
// RUN: -convert-scf-to-cf \
1416
// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
1517
// RUN: -reconcile-unrealized-casts | \
1618
// RUN: mlir-runner -e main -entry-point-result=void \
@@ -38,6 +40,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
3840
return
3941
}
4042

43+
func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
44+
%dim_0: index,
45+
%dim_1: index,
46+
%dim_2: index) {
47+
%subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
48+
memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
49+
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
50+
return
51+
}
52+
53+
4154
func.func @main() {
4255
%0 = arith.constant 0 : index
4356
%1 = arith.constant 1 : index
@@ -105,6 +118,14 @@ func.func @main() {
105118
// CHECK-NOT: ERROR: Runtime op verification failed
106119
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
107120

121+
%alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
122+
%alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
123+
// CHECK-NOT: ERROR: Runtime op verification failed
124+
%dim_0 = arith.constant 0 : index
125+
%dim_1 = arith.constant 4 : index
126+
%dim_2 = arith.constant 1 : index
127+
func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
128+
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
108129

109130
return
110131
}

0 commit comments

Comments
 (0)