Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H

#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
#include <mlir/Pass/AnalysisManager.h>
#include <optional>

namespace mlir::dataflow {
Expand Down Expand Up @@ -101,10 +102,19 @@ struct RunLivenessAnalysis {
RunLivenessAnalysis(Operation *op);

const Liveness *getLiveness(Value val);
// This only remarks that Liveness results are stale.
void invalidate() { valid = false; }
/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
/// The function is called by analysis_impl::isInvalidated.
bool isInvalidated(AnalysisManager::PreservedAnalyses &) const {
return !valid;
}

private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
bool valid{true};
};

} // end namespace mlir::dataflow
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,6 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
}

const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
assert(valid && "getLiveness called after invalidate");
return solver.lookupState<Liveness>(val);
}
152 changes: 147 additions & 5 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
Expand Down Expand Up @@ -869,10 +870,151 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
};
} // namespace

/// If the target of CallOp is a public function and at least one argument is
/// NonLive, privatize the function. Our strategy here is separation interface
/// and implementation. eg.
///
/// public void foo(int unused){...}
/// =>
/// public void foo(int unused) { // old function, interface
/// return __foo_privatized(unused);
/// }
///
/// private void __foo_privatized(int unused) { // the new private function, or
/// implementation.
/// ... // the function body of the
/// original function.
/// }
///
/// Returns true if any IR changes were made, false otherwise.
static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
RunLivenessAnalysis &la) {
Operation *callableOp = callOp.resolveCallable();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this is expensive: can we thread a SymbolTable somehow?

auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
if (!funcOp || !funcOp.isPublic())
return false;

LDBG() << "Processing callOp " << callOp << " target is a public function: "
<< funcOp.getOperation()->getName();

// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(callOp.getArgOperands());
BitVector nonLiveArgs = markLives(arguments, DenseSet<Value>(), la);
nonLiveArgs = nonLiveArgs.flip();

if (nonLiveArgs.count() > 0) {
OpBuilder rewriter(moduleOp.getContext());

// Clone function and create private version
FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone());

// Set visibility = 'private' and a new name for the cloned function
SymbolTable::setSymbolVisibility(clonedFunc,
SymbolTable::Visibility::Private);
std::string newName = "__" + funcOp.getName().str() + "_privatized";
clonedFunc.setName(newName);

// Insert the cloned function into the module
rewriter.setInsertionPointAfter(funcOp);
rewriter.insert(clonedFunc);

// Replace ALL callsites of the original function to call the cloned
// function directly
LogicalResult result = SymbolTable::replaceAllSymbolUses(
funcOp, clonedFunc.getNameAttr(), moduleOp);

if (result.failed()) {
LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
return false;
}

LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
<< newName;

// Transform the original funcOp into a wrapper that calls the cloned
// function
Region &funcBody = funcOp.getFunctionBody();

// Clean the original function body
funcBody.dropAllReferences();
funcBody.getBlocks().clear();

// Create a new entry block for the wrapper function
Block *wrapperBlock = rewriter.createBlock(&funcBody);

// Add block arguments that match the function signature
for (Type argType : funcOp.getArgumentTypes()) {
wrapperBlock->addArgument(argType, funcOp.getLoc());
}

// Set insertion point to the new block
rewriter.setInsertionPointToStart(wrapperBlock);

// Clone the original call operation and update its callee
auto clonedCallOp = cast<CallOpInterface>(callOp->clone());
// Update the callee symbol reference to point to the new private function
auto symbolRef =
SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
clonedCallOp.setCalleeFromCallable(symbolRef);
// Set the call arguments to use the wrapper block's arguments
clonedCallOp->setOperands(wrapperBlock->getArguments());
rewriter.insert(clonedCallOp);

// Create return operation of the same type as the original function's
// return
Operation *returnOp = nullptr;
for (Block &block : clonedFunc.getFunctionBody()) {
if (block.getNumSuccessors() > 0)
continue;

Operation *terminator = block.getTerminator();
if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
returnOp = terminator;
break; // Use first return as template
}
}

if (returnOp) {
Operation *newReturnOp = returnOp->clone();
newReturnOp->setOperands(clonedCallOp->getResults());
newReturnOp->setLoc(funcOp.getLoc());
rewriter.insert(newReturnOp);
}
return true; // Changes were made
}

return false;
}

void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
AnalysisManager am = getAnalysisManager();
RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();

// Only privatize public functions if liveness analysis is inter-procedural.
if (la->getSolverConfig().isInterprocedural()) {
bool changed = false;
module->walk([&](CallOpInterface callOp) {
if (processCallOp(callOp, cast<ModuleOp>(module), *la)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cast does not seem safe to me: the pass isn't a modulePass right now.

changed = true;
}
});

if (changed) {
LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
auto &pa = getPassState().preservedAnalyses;
bool preserved = pa.isPreserved<RunLivenessAnalysis>();
la->invalidate();
am.invalidate(pa);
la = &am.getAnalysis<RunLivenessAnalysis>();
// If RunLivenessAnalysis was previously preserved, preserved the updated
// results.
if (preserved) {
pa.preserve<RunLivenessAnalysis>();
}
Comment on lines +1012 to +1014
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (preserved) {
pa.preserve<RunLivenessAnalysis>();
}
if (preserved)
pa.preserve<RunLivenessAnalysis>();

Nit: no-trivial-braces in MLIR

}
}

// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
Expand All @@ -883,19 +1025,19 @@ void RemoveDeadValues::runOnOperation() {

module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
processFuncOp(funcOp, module, *la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
processBranchOp(branchOp, *la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
} else {
processSimpleOp(op, la, deadVals, finalCleanupList);
processSimpleOp(op, *la, deadVals, finalCleanupList);
}
});

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,54 @@ module @return_void_with_unused_argument {
}
}

// check that public functions with non-live arguments correctly.
module @public_function_with_nonlive_arguments {
// the function signature is immutable because it is public.
func.func public @public_fn_with_unused_argument(%unused: i32) -> () {
return
}
// CHECK-LABEL: func.func @main
// CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
func.func @main() -> () {
%zero = arith.constant 0 : i32
call @public_fn_with_unused_argument(%zero) : (i32) -> ()
return
}

// CHECK-LABEL: func.func @main2
// CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
func.func @main2(%arg0: i1) {
%0 = scf.if %arg0 -> (i32) {
%c1_i32 = arith.constant 1 : i32
scf.yield %c1_i32 : i32
} else {
%c0_i32 = arith.constant 0 : i32
scf.yield %c0_i32 : i32
}

call @public_fn_with_unused_argument(%0) : (i32) -> ()
return
}

func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) {
%one = arith.constant 1 : i32
%two = arith.constant 2 : i32
%three = arith.constant 4 : i32

return %one, %two, %three: i32, i32, i32
}

// CHECK-LABEL: func.func @main3
// CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32)
func.func @main3(%arg: i32) -> () {
%one = arith.constant 1 : i32
%scalar = arith.addi %arg, %one: i32

call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32)
return
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add two CFG examples where the blocks are listed in different order to ensure you're not sensitive to the order the blocks are in-memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi, @joker-eph
I did add a testcase as you said. Then I realize that it's non-trivial to propagate liveness in RemoveDeadValues.

Here is a testcase only for RegionBranchOpinterface. From %0 is live at line 11, we need to mark %0 is live at line 2. After that, we need to mark %c1_i32 at line 4 and c0_i32 at line 7 live as well. In order words, we need to walk function @main3 preorder + backward.

     1	  func.func @main3(%arg0: i1) {
     2	    %0 = scf.if %arg0 -> (i32) {
     3	      %c1_i32 = arith.constant 1 : i32
     4	      scf.yield %c1_i32 : i32
     5	    } else {
     6	      %c0_i32 = arith.constant 0 : i32
     7	      scf.yield %c0_i32 : i32
     8	    }
     9	    %mem = memref.alloc() : memref<4xf32>
    10	
    11	    call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
    12	    return
    13	  }

I manage to fix this in propagateBackward. It pretty much redo what liveness analysis has done. TBH, I don't think this is the right way to proceed. RemoveDeadValues should keep its own single responsibility.

I take a step back and think about why we end up here. The very reason we try to propagate liveness in it because:

  1. liveness is immutable.
  2. We somehow need to update the NonLive arguments of a public function.

How about we just introduce a new pass: 'privatize-public-function' right before 'remove-dead-values'.

  1. It deploys separation of interface and implementation.
  2. If nothing changes, we preserve liveness. Otherwise, we invalidate it and let remove-dead-value recompute.

Here is a demo what this pass transforms.
i think we can waive cost model because we don't clone function body. We just create a
thin wrapper.

public void foo(int unused){...}

void main() {
arg = compute();
call foo(arg);
}
=> 
public void foo(int unused) { // interface
return __foo_impl(unused);
}

private void __foo_impl(int unused) { //implementation, new function.
... // the function body of the original foo.
}

void main() {
arg = compute();
call __foo_impl(arg);
}

This is my prototype here. do you think it's more feasible solution?
navyxliu@b73f537#diff-904855c22d662d8afbc11c40fe2906259836ba53e907e1cc899e6355358ec482

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable, but this needs to be callable from RemoveDeadValues itself (the pass can't crash itself here)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figure out how to invalidate an analysis in analysis-manager, so I can combine them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update this branch. could you take a look at the new implementation?


// -----

// CHECK-LABEL: module @dynamically_unreachable
Expand Down