diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index c20c54551cdf8..92c05d87a002d 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -305,6 +305,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // since it forwards only to non-live value(s) (%1#1). Operation *lastReturnOp = funcOp.back().getTerminator(); size_t numReturns = lastReturnOp->getNumOperands(); + if (numReturns == 0) + return; BitVector nonLiveRets(numReturns, true); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index e549926b90456..21d53b0742e07 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -446,6 +446,21 @@ func.func @kernel(%arg0: memref<18xf32>) { // ----- + +// CHECK-LABEL: llvm_unreachable +// CHECK-LABEL: @fn_with_llvm_unreachable +// CHECK-LABEL: @main +// CHECK: llvm.return +module @llvm_unreachable { + func.func private @fn_with_llvm_unreachable(%arg0: tensor<4x4xf32>) -> tensor<4x4xi1> { + llvm.unreachable + } + func.func private @main(%arg0: tensor<4x4xf32>) { + %0 = call @fn_with_llvm_unreachable(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xi1> + llvm.return + } +} + // CHECK: func.func private @no_block_func_declaration() func.func private @no_block_func_declaration() -> ()