diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index 15eb51a6dcab2..181b4846835c0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" @@ -43,6 +44,32 @@ struct StructuredOpInterface auto zero = arith::ConstantIndexOp::create(builder, loc, 0); auto one = arith::ConstantIndexOp::create(builder, loc, 1); + Value iterationDomainIsNonDegenerate; + for (auto [start, end] : llvm::zip(starts, ends)) { + auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start); + auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); + + // Loop Trip count > 0 iff start < end + Value dimensionHasNonZeroTripCount = builder.create( + loc, index::IndexCmpPredicate::SLT, startValue, endValue); + + if (!iterationDomainIsNonDegenerate) { + iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount; + } else { + // Iteration domain is non-degenerate iff all dimensions have loop trip + // count > 0 + iterationDomainIsNonDegenerate = builder.create( + loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount); + } + } + + if (!iterationDomainIsNonDegenerate) + return; + + auto ifOp = builder.create(loc, iterationDomainIsNonDegenerate, + /*withElseRegion=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Subtract one from the loop ends before composing with the indexing map transform(ends, ends.begin(), [&](OpFoldResult end) { auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); @@ -110,6 +137,7 @@ struct StructuredOpInterface builder.createOrFold(loc, cmpOp, msg); } } + builder.setInsertionPointAfter(ifOp); } }; diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir index 9f4393efc87bf..127ab70cb4539 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir @@ -103,6 +103,17 @@ func.func @main() { // CHECK: unexpected negative result on dimension #0 of input/output operand #0 func.call @reverse_from_3(%d5x) : (tensor) -> (tensor) + %c0x = arith.constant dense<1.0> : tensor<0xf32> + %d0x = tensor.cast %c0x : tensor<0xf32> to tensor + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @fill_empty_1d(%d0x) : (tensor) -> (tensor) + + %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32> + %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @fill_empty_2d(%d0x5) : (tensor) -> (tensor) + return } @@ -297,3 +308,15 @@ func.func @reverse_from_3(%arg0: tensor) -> (tensor) { } -> tensor return %result : tensor } + +func.func @fill_empty_1d(%arg0: tensor) -> (tensor) { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor + return %0 : tensor +} + +func.func @fill_empty_2d(%arg0: tensor) -> (tensor) { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor + return %0 : tensor +}