@navyxliu
Contributor

@navyxliu navyxliu commented Oct 6, 2025

Problem

Liveness analysis is inter-procedural. If a public function has unused arguments, their non-liveness propagates to the callers. From the perspective of RemoveDeadValues, however, the signature of a public function is immutable, and the pass cannot cope with this situation: on one side it deletes the outgoing arguments at the call site, while on the other side it keeps the function intact.

Solution

We deploy a pre-processing pass called 'privatize-public-function'. It attempts to split interface and implementation if and only if

  1. liveness is inter-procedural,
  2. the target of the callOp is a public function, and
  3. at least one argument is NonLive.

The original function becomes the interface: its body is just a forwarding call to the implementation. The new function is the implementation: it is private and takes over the original function body.

If the IR has changed, we invalidate LivenessAnalysis and recompute it for the subsequent remove-dead-values run.
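As a rough illustration of the split described above, the following standalone C++ sketch models a module as a list of toy function records. `ToyFunc`, `privatizePublicFunction`, and the `_impl` suffix are hypothetical names chosen for this example. The sketch does not use any MLIR API, and it receives the callee directly instead of discovering it through a callOp.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for IR entities; these are NOT MLIR classes.
enum class Visibility { Public, Private };

struct ToyFunc {
  std::string name;
  Visibility visibility;
  std::string forwardsTo; // non-empty when the body is only a forwarding call
  std::string body;       // opaque placeholder for the real function body
};

// Splits module[index] into a public interface and a private implementation
// when all three conditions hold. Returns true if the module changed, in
// which case a caller would have to recompute liveness.
bool privatizePublicFunction(std::vector<ToyFunc> &module, std::size_t index,
                             const std::vector<bool> &argIsLive,
                             bool interprocedural) {
  assert(index < module.size() && "function index out of range");
  // Conditions 1 and 2: inter-procedural liveness and a public callee.
  if (!interprocedural || module[index].visibility != Visibility::Public)
    return false;
  // Condition 3: at least one argument is non-live.
  if (std::find(argIsLive.begin(), argIsLive.end(), false) == argIsLive.end())
    return false;

  // The implementation is a private copy that keeps the original body.
  ToyFunc impl = module[index];
  impl.name += "_impl"; // hypothetical naming scheme
  impl.visibility = Visibility::Private;

  // The interface keeps its name and visibility but only forwards.
  module[index].body.clear();
  module[index].forwardsTo = impl.name;
  module.push_back(std::move(impl)); // may reallocate; no references are held
  return true;
}
```

In this model the dead argument ends up only in the private copy, so the public signature stays untouched. The boolean return value mirrors the "if the IR has changed" check that triggers the liveness recomputation.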

Test plan

./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
before:

./input.mlir:13:5: error: null operand found
    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
    ^
./input.mlir:13:5: note: see current operation: "func.call"(%arg0, <<NULL VALUE>>) <{callee = @immutable_fn_return_void_with_unused_argument}> : (i32, <<NULL TYPE>>) -> ()
after: pass

Alternatives

  1. createDummyArgument replaces dead arguments with dummies, e.g. when the type can create a value with a zero initializer, such as 'T x{0}'. The downside of this method is that we may not be able to construct trivial values for all types, in particular aggregate types.

  2. We add another DenseSet called 'liveSet'. Its initial values are the outgoing arguments passed to public functions. It propagates liveness backward, just like liveness analysis.
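The trade-off in the first alternative can be sketched in plain C++17. `makeDummy`, `pickOperand`, and `ToyAggregate` are hypothetical names for this illustration, and restricting dummies to arithmetic types is an assumption of the sketch, not a statement about which MLIR types have a zero attribute.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>

// A type with no obvious "zero" value; stands in for an aggregate.
struct ToyAggregate {
  int *data;
  int size;
};

// Returns a zero-valued placeholder when T is arithmetic, std::nullopt
// otherwise. Restricting to arithmetic types is an assumption of this sketch.
template <typename T>
std::optional<T> makeDummy() {
  if constexpr (std::is_arithmetic_v<T>)
    return T(0);
  else
    return std::nullopt;
}

// Picks the operand to pass at the call site. When no dummy exists, the
// original value is kept and reported as needing to stay live, which is
// where the second alternative would take over.
template <typename T>
T pickOperand(const T &original, bool &mustStayLive) {
  if (std::optional<T> dummy = makeDummy<T>()) {
    mustStayLive = false;
    return *dummy;
  }
  mustStayLive = true;
  return original;
}
```

For an `int` operand the call site would receive a zero placeholder and the original value becomes removable. For `ToyAggregate` no placeholder can be built, so the original operand has to be preserved.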

Xin Liu and others added 4 commits October 5, 2025 20:46
This diff also changes the traversal order from forward to backward for
regions, blocks, and ops. This order guarantees that a liveness update at a
call site can propagate to the defs of its arguments.

```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
```
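Why the backward order matters can be seen in a small standalone model. The sketch below is an assumption-laden simplification: `ToyOp`, `pinsOperands`, and `propagateLiveBackward` are hypothetical names, values are plain strings, and every operand of a pinning op is treated as live, whereas the patch only does this for operands that cannot be replaced by a dummy.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy operation: at most one named result and a list of operand names.
struct ToyOp {
  std::string result;                // empty when the op has no result
  std::vector<std::string> operands; // names of the values it consumes
  bool pinsOperands = false;         // e.g. a call to a public function
};

// Walks the block from the last op to the first. An op whose result is
// already in the live set, or which pins its operands, marks all of its
// operands live. Because uses are visited before defs, one pass suffices
// for straight-line code.
std::set<std::string> propagateLiveBackward(const std::vector<ToyOp> &block) {
  std::set<std::string> liveSet;
  for (auto it = block.rbegin(); it != block.rend(); ++it) {
    bool resultLive = !it->result.empty() && liveSet.count(it->result) > 0;
    if (it->pinsOperands || resultLive)
      liveSet.insert(it->operands.begin(), it->operands.end());
  }
  return liveSet;
}
```

With a forward walk, the op defining a call argument would be visited before the call site has had a chance to add that argument to the set, so the def could be scheduled for removal too early.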
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels Oct 6, 2025
@llvmbot
Member

llvmbot commented Oct 6, 2025

@llvm/pr-subscribers-mlir

Author: xin liu (navyxliu)

Changes

This diff is a redo of #160242.

Problem

Liveness analysis is inter-procedural. If a public function has unused arguments, their non-liveness propagates to the callers. From the perspective of RemoveDeadValues, however, the signature of a public function is immutable, and the pass cannot cope with this situation: on one side it deletes the outgoing arguments at the call site, while on the other side it keeps the function intact.

Solution

We deploy two methods to fix this bug.

  1. createDummyArgument replaces a dead argument with a dummy, e.g. when the type can create a value with a zero initializer, such as 'T x{0}'.

  2. As a fallback, we add another DenseSet called 'liveSet'. Its initial values are the outgoing arguments passed to public functions. It propagates liveness backward, just like liveness analysis.

Test plan

./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
before:

./input.mlir:13:5: error: null operand found
    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
    ^
./input.mlir:13:5: note: see current operation: "func.call"(%arg0, <<NULL VALUE>>) <{callee = @immutable_fn_return_void_with_unused_argument}> : (i32, <<NULL TYPE>>) -> ()
after: pass

Full diff: https://github.com/llvm/llvm-project/pull/162038.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h (+3)
  • (modified) mlir/include/mlir/IR/Visitors.h (+14)
  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+95-19)
  • (modified) mlir/test/Transforms/remove-dead-values.mlir (+18)
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
 
   const Liveness *getLiveness(Value val);
 
+  /// Return the configuration of the solver used for this analysis.
+  const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
   }
 };
 
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+  template <typename T>
+  static auto makeIterable(T &range) {
+    if constexpr (std::is_same<T, Operation>()) {
+      /// Make operations iterable: return the list of regions.
+      return llvm::reverse(range.getRegions());
+    } else {
+      /// Regions and block are already iterable.
+      return llvm::reverse(range);
+    }
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
@@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                    RunLivenessAnalysis &la) {
+                    const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
   for (Value value : values) {
+    if (liveSet.contains(value)) {
+      LDBG() << "Value " << value << " is marked live by CallOp";
+      return true;
+    }
+
     if (nonLiveSet.contains(value)) {
       LDBG() << "Value " << value << " is already marked non-live (dead)";
       continue;
@@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
+                           const DenseSet<Value> &liveSet,
                            RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
@@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
              << " is already marked non-live (dead) at index " << index;
       continue;
     }
-
+    if (liveSet.contains(value)) {
+      continue;
+    }
     const Liveness *liveness = la.getLiveness(value);
     // It is important to note that when `liveness` is null, we can't tell if
     // `value` is live or not. So, the safe option is to consider it live. Also,
@@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+  for (Value val : op->getResults()) {
+    if (liveSet.contains(val)) {
+      LDBG() << "Simple op is used by a public function, "
+                "preserving it: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
+      liveSet.insert_range(op->getOperands());
+      return;
+    }
+  }
+
+  if (!isMemoryEffectFree(op) ||
+      hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
                           RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
-                          RDVFinalCleanupList &cl) {
+                          DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
   if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
@@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+    BitVector liveCallRets =
+        markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }
 
+// Create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+  OpBuilder builder(callOp.getOperation());
+  Type type = oldVal.getType();
+
+  // Create zero constant for any supported type
+  if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+    return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+  }
+  return {};
+}
+
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet) {
+  if (!la.getSolverConfig().isInterprocedural())
+    return;
+
+  Operation *callableOp = callOp.resolveCallable();
+  auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic()) {
+    return;
+  }
+
+  LDBG() << "processCallOp to a public function: " << funcOp.getName();
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+  nonLiveArgs = nonLiveArgs.flip();
+
+  if (nonLiveArgs.count() > 0) {
+    LDBG() << funcOp.getName() << " contains NonLive arguments";
+    // The number of operands in the call op may not match the number of
+    // arguments in the func op.
+    SmallVector<OpOperand *> callOpOperands =
+        operandsToOpOperands(callOp.getArgOperands());
+
+    for (int index : nonLiveArgs.set_bits()) {
+      OpOperand *operand = callOpOperands[index];
+      Value oldVal = operand->get();
+      if (Value dummy = createDummyArgument(callOp, oldVal)) {
+        callOp->setOperand(operand->getOperandNumber(), dummy);
+        nonLiveSet.insert(oldVal);
+      } else {
+        liveSet.insert(oldVal);
+      }
+    }
+  }
+}
+
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`. The processing involves two scenarios:
 ///
@@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
+                                  DenseSet<Value> &liveSet,
                                   RDVFinalCleanupList &cl) {
   LDBG() << "Processing region branch op: "
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+    liveResults =
+        markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
       if (region.empty())
         continue;
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // attributed to something else.
   // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+      !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
@@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            RDVFinalCleanupList &cl) {
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
@@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 
     // Do (2)
     BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, la).flip();
+        markLives(operandValues, nonLiveSet, liveSet, la).flip();
     collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
                          successorNonLive);
 
@@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
+  // mark outgoing arguments to a public function LIVE. We also propagate
+  // liveness backward.
+  DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
   // end of this pass.
   RDVFinalCleanupList finalCleanupList;
 
-  module->walk([&](Operation *op) {
+  module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+      processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+                            finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      processBranchOp(branchOp, la, deadVals, finalCleanupList);
+      processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
     } else if (isa<CallOpInterface>(op)) {
-      // Nothing to do because this op is associated with a function op and gets
-      // cleaned when the latter is cleaned.
+      processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
     } else {
-      processSimpleOp(op, la, deadVals, finalCleanupList);
+      processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }
   });
 
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
     call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
     return %unused : memref<4xi32>
   }
+
+  // the function signature is immutable because it is public.
+  func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+    return
+  }
+
+  // CHECK-LABEL: func.func @main2
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+  // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+  func.func @main2() -> () {
+    %one = arith.constant 1 : i32
+    %scalar = arith.addi %one, %one: i32
+    %mem = memref.alloc() : memref<4xf32>
+
+    call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+    return
+  }
 }
 
 // -----

@llvmbot
Member

llvmbot commented Oct 6, 2025

@llvm/pr-subscribers-mlir-core

Author: xin liu (navyxliu)


@navyxliu
Contributor Author

navyxliu commented Oct 6, 2025

Hi @joker-eph,
I can't reopen PR160242 because I force-pushed my branch. Could you take a look at this new patch? Thanks.

call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
return
}
}
Collaborator

Can you add two CFG examples where the blocks are listed in different order to ensure you're not sensitive to the order the blocks are in-memory.

Contributor Author

Hi @joker-eph,
I added a test case as you suggested, and then realized that propagating liveness inside RemoveDeadValues is non-trivial.

Here is a test case that exercises only RegionBranchOpInterface. Because %0 is live at line 11, we need to mark %0 live at line 2. After that, we also need to mark %c1_i32 at line 4 and %c0_i32 at line 7 live. In other words, we need to walk function @main3 in pre-order and backward.

     1	  func.func @main3(%arg0: i1) {
     2	    %0 = scf.if %arg0 -> (i32) {
     3	      %c1_i32 = arith.constant 1 : i32
     4	      scf.yield %c1_i32 : i32
     5	    } else {
     6	      %c0_i32 = arith.constant 0 : i32
     7	      scf.yield %c0_i32 : i32
     8	    }
     9	    %mem = memref.alloc() : memref<4xf32>
    10	
    11	    call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
    12	    return
    13	  }

I managed to fix this in propagateBackward, but it pretty much redoes what liveness analysis has already done. TBH, I don't think this is the right way to proceed: RemoveDeadValues should keep its single responsibility.
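
For illustration only, the backward propagation described above can be modeled on a toy dependency graph. This is a minimal sketch, not the PR's propagateBackward and not MLIR's API: `ToyValue`, `propagateLiveBackward`, and `buildMain3` are hypothetical names, and treating the `scf.if` condition as a producer of %0 is a modeling assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for an SSA value (hypothetical, not an MLIR type).
// `producers` lists the values this value is computed from.
struct ToyValue {
  std::vector<std::size_t> producers;
};

// Marks `root` and everything it transitively depends on as live.
// Returns one flag per value; true means live.
std::vector<bool> propagateLiveBackward(const std::vector<ToyValue> &values,
                                        std::size_t root) {
  std::vector<bool> live(values.size(), false);
  std::vector<std::size_t> worklist{root};
  while (!worklist.empty()) {
    std::size_t v = worklist.back();
    worklist.pop_back();
    if (live[v])
      continue; // already visited
    live[v] = true;
    for (std::size_t p : values[v].producers)
      worklist.push_back(p);
  }
  return live;
}

// Models the dependencies of @main3 from the test case above.
// Indices: 0 = %arg0, 1 = %c1_i32, 2 = %c0_i32, 3 = %0, 4 = %mem.
std::vector<ToyValue> buildMain3() {
  std::vector<ToyValue> values(5);
  // Assumption: %0 depends on the condition and on both yielded values.
  values[3].producers = {0, 1, 2};
  return values;
}
```

Starting from %0 (index 3), the walk reaches both yielded constants, which is the same work the liveness analysis has already done once.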

I took a step back and thought about why we ended up here. We are trying to propagate liveness inside RemoveDeadValues because:

  1. Liveness is immutable.
  2. We somehow need to update the NonLive arguments of a public function.

How about we introduce a new pass, 'privatize-public-function', that runs right before 'remove-dead-values'?

  1. It separates interface from implementation.
  2. If nothing changes, we preserve liveness. Otherwise, we invalidate it and let remove-dead-values recompute it.

Here is a demo of what this pass does.
I think we can skip a cost model because we don't clone the function body; we just create a thin wrapper.

public void foo(int unused) {...}

void main() {
  arg = compute();
  call foo(arg);
}
=>
public void foo(int unused) { // interface
  return __foo_impl(unused);
}

private void __foo_impl(int unused) { // implementation, new function.
  ... // the function body of the original foo.
}

void main() {
  arg = compute();
  call __foo_impl(arg);
}

This is my prototype. Do you think it's a more feasible solution?
navyxliu@b73f537#diff-904855c22d662d8afbc11c40fe2906259836ba53e907e1cc899e6355358ec482
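
The split in the demo above can be sketched on a toy module. This is a minimal model under stated assumptions, not the prototype's code: `ToyFunc`, `ToyModule`, and `privatizePublicFunctions` are hypothetical names, `hasNonLiveArg` stands in for the liveness query, and the `__<name>_impl` naming simply follows the demo.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy function record (hypothetical, not an MLIR type).
struct ToyFunc {
  bool isPublic = false;
  bool hasNonLiveArg = false;       // stand-in for the liveness query
  std::vector<std::string> callees; // targets of the calls in the body
};

using ToyModule = std::map<std::string, ToyFunc>;

// Splits every public function that is called inside the module and has a
// non-live argument into a public forwarding wrapper and a private
// implementation. Returns true if the module changed.
bool privatizePublicFunctions(ToyModule &module) {
  // Collect targets first so the map is not mutated while scanning it.
  std::set<std::string> targets;
  for (const auto &entry : module)
    for (const std::string &callee : entry.second.callees) {
      auto it = module.find(callee);
      if (it != module.end() && it->second.isPublic &&
          it->second.hasNonLiveArg)
        targets.insert(callee);
    }

  for (const std::string &name : targets) {
    const std::string implName = "__" + name + "_impl";
    // Retarget the existing call sites to the implementation.
    for (auto &entry : module)
      for (std::string &callee : entry.second.callees)
        if (callee == name)
          callee = implName;
    // The private implementation takes over the original body.
    ToyFunc impl = module[name];
    impl.isPublic = false;
    module[implName] = impl;
    // The original function stays public and only forwards the call.
    module[name].callees = {implName};
  }
  return !targets.empty();
}
```

Because the body is moved rather than cloned, the only new code is the forwarding call. The boolean result is what a caller would use to decide whether liveness has to be recomputed.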

Collaborator

Seems reasonable, but this needs to be callable from RemoveDeadValues itself (the pass can't crash itself here)

Contributor Author

I figured out how to invalidate an analysis in the analysis manager, so I can combine the two.

Contributor Author

I updated this branch. Could you take a look at the new implementation?

This diff also tries to propagate liveness recursively.
@github-actions

github-actions bot commented Oct 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@navyxliu navyxliu requested a review from joker-eph October 10, 2025 07:17
@navyxliu navyxliu changed the title [MLIR][RemoveDeadValues] Mark arguments of a public function Live [MLIR][RemoveDeadValues] Privatize public function with NonLive arguments before RDV. Oct 10, 2025
@navyxliu
Contributor Author

cc @victor-eds

Collaborator

@joker-eph joker-eph left a comment

I think this should be fine, I have some minor comments inline, PTAL

if (la->getSolverConfig().isInterprocedural()) {
bool changed = false;
module->walk([&](CallOpInterface callOp) {
if (processCallOp(callOp, cast<ModuleOp>(module), *la)) {
Collaborator

This cast does not seem safe to me: the pass isn't a modulePass right now.

Comment on lines +1012 to +1014
if (preserved) {
pa.preserve<RunLivenessAnalysis>();
}
Collaborator

Suggested change
if (preserved) {
pa.preserve<RunLivenessAnalysis>();
}
if (preserved)
pa.preserve<RunLivenessAnalysis>();

Nit: no-trivial-braces in MLIR

/// Returns true if any IR changes were made, false otherwise.
static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
RunLivenessAnalysis &la) {
Operation *callableOp = callOp.resolveCallable();
Collaborator

I suspect this is expensive: can we thread a SymbolTable somehow?
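
To make the concern concrete, here is a toy sketch of a symbol index that is built once and then shared by every call-site lookup, rather than resolving each callee from scratch. `ToySymbol` and `ToySymbolIndex` are hypothetical names and this does not use MLIR's SymbolTable API; it only illustrates the caching idea.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Toy symbol record (hypothetical, not an MLIR type).
struct ToySymbol {
  std::string name;
  bool isPublic;
};

// Name-to-symbol index built once per pass run.
// It stores pointers into `symbols`, so that vector must outlive the index
// and must not be resized while the index is in use.
class ToySymbolIndex {
public:
  explicit ToySymbolIndex(const std::vector<ToySymbol> &symbols) {
    for (const ToySymbol &s : symbols)
      index.emplace(s.name, &s);
  }

  // Returns the symbol named `name`, or nullptr if it is not defined here.
  const ToySymbol *lookup(const std::string &name) const {
    auto it = index.find(name);
    return it == index.end() ? nullptr : it->second;
  }

private:
  std::unordered_map<std::string, const ToySymbol *> index;
};
```

With an index like this, each lookup is a hash-map query on average, and the cost of building the table is paid once instead of once per call op.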

Labels

mlir:core MLIR Core Infrastructure mlir

3 participants