diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 0d2e2ed85549d..ee47ebfeacefb 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/DebugLog.h" @@ -505,12 +506,21 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // If the call invokes an external function (or a function treated as // external due to config), defer to the corresponding extension hook. // By default, it just does `visitCallOperand` for all operands. + // + // If callable is a public function, treat it as external. + // This is because a public function has potential callers we can't + // visit, and thus we need to be conservative and consider all + // arguments live. OperandRange argOperands = call.getArgOperands(); MutableArrayRef argOpOperands = operandsToOpOperands(argOperands); Region *region = callable.getCallableRegion(); - if (!region || region->empty() || - !getSolverConfig().isInterprocedural()) { + auto isPublicFunction = [&]() { + auto funcOp = dyn_cast(callableOp); + return funcOp && funcOp.isPublic(); + }; + if (!getSolverConfig().isInterprocedural() || !region || + region->empty() || isPublicFunction()) { visitExternalCallImpl(call, operandLattices, resultLattices); return success(); } diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index fa2c145bd3701..8d50179c863a3 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -569,6 +569,18 @@ module @return_void_with_unused_argument { call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> () return %unused : memref<4xi32> } + // the function signature is immutable because it is public. + func.func public @public_fn_with_unused_argument(%unused: i32) -> () { + return + } + // CHECK-LABEL: func.func @main2 + // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32 + // CHECK: call @public_fn_with_unused_argument(%[[UNUSED]]) : (i32) -> () + func.func @main2() -> () { + %zero = arith.constant 0 : i32 + call @public_fn_with_unused_argument(%zero) : (i32) -> () + return + } } // -----