Skip to content

Commit d08cbc1

Browse files
Hanumanth04 and Hanumanth Hanumantharayappa authored
[mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 (llvm#163791)
Runtime verification on Linalg structured ops unconditionally computed `end - 1` to determine the last iteration index before composing indexing maps. This caused spurious "negative index" assertion failures when operating on empty tensors (tensors with a dimension of size 0).

The issue occurs because:

1. Empty tensors create loop ranges [0, 0) with zero trip count.
2. Computing `end - 1 = 0 - 1 = -1` creates a fictitious negative index.
3. The negative index check triggers even though no loop iterations occur.

The fix is to guard all runtime verification with a check that ensures all loop ranges are non-empty (start < end) before performing any index arithmetic.

Example MLIR that previously failed:

```mlir
func.func @fill_empty() -> tensor<0xi32> {
  %c0 = arith.constant 0 : i32
  %empty = tensor.empty() : tensor<0xi32>
  %filled = linalg.fill ins(%c0 : i32) outs(%empty : tensor<0xi32>) -> tensor<0xi32>
  return %filled : tensor<0xi32>
}
```

---------

Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
1 parent b307347 commit d08cbc1

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Index/IR/IndexOps.h"
1818
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1919
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/Dialect/SCF/IR/SCF.h"
2021
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
2223

@@ -43,6 +44,32 @@ struct StructuredOpInterface
4344
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
4445
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
4546

47+
Value iterationDomainIsNonDegenerate;
48+
for (auto [start, end] : llvm::zip(starts, ends)) {
49+
auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
50+
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
51+
52+
// Loop Trip count > 0 iff start < end
53+
Value dimensionHasNonZeroTripCount = builder.create<index::CmpOp>(
54+
loc, index::IndexCmpPredicate::SLT, startValue, endValue);
55+
56+
if (!iterationDomainIsNonDegenerate) {
57+
iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
58+
} else {
59+
// Iteration domain is non-degenerate iff all dimensions have loop trip
60+
// count > 0
61+
iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
62+
loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
63+
}
64+
}
65+
66+
if (!iterationDomainIsNonDegenerate)
67+
return;
68+
69+
auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
70+
/*withElseRegion=*/false);
71+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
72+
4673
// Subtract one from the loop ends before composing with the indexing map
4774
transform(ends, ends.begin(), [&](OpFoldResult end) {
4875
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +137,7 @@ struct StructuredOpInterface
110137
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
111138
}
112139
}
140+
builder.setInsertionPointAfter(ifOp);
113141
}
114142
};
115143

mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ func.func @main() {
103103
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
104104
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
105105

106+
%c0x = arith.constant dense<1.0> : tensor<0xf32>
107+
%d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
108+
// CHECK-NOT: ERROR: Runtime op verification failed
109+
func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
110+
111+
%c0x5 = arith.constant dense<0.0> : tensor<0x5xf32>
112+
%d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32>
113+
114+
// CHECK-NOT: ERROR: Runtime op verification failed
115+
func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
116+
106117
return
107118
}
108119

@@ -297,3 +308,15 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
297308
} -> tensor<?xf32>
298309
return %result : tensor<?xf32>
299310
}
311+
312+
// Fills a dynamically shaped 1-D f32 tensor with the constant 0.0 via
// linalg.fill and returns the filled tensor. @main invokes this with a
// tensor<0xf32> cast to tensor<?xf32>, i.e. an operand whose single
// dimension has size 0 at runtime.
func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
313+
%c0 = arith.constant 0.0 : f32
314+
%0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
315+
return %0 : tensor<?xf32>
316+
}
317+
318+
// Fills a dynamically shaped 2-D f32 tensor with the constant 0.0 via
// linalg.fill and returns the filled tensor. @main invokes this with a
// tensor<0x5xf32> cast to tensor<?x?xf32>, so only the leading dimension
// has size 0 at runtime while the trailing one does not.
func.func @fill_empty_2d(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
319+
%c0 = arith.constant 0.0 : f32
320+
%0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
321+
return %0 : tensor<?x?xf32>
322+
}

0 commit comments

Comments
 (0)