Skip to content
Closed
10 changes: 10 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H

#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
#include <mlir/Pass/AnalysisManager.h>
#include <optional>

namespace mlir::dataflow {
Expand Down Expand Up @@ -101,10 +102,19 @@ struct RunLivenessAnalysis {
RunLivenessAnalysis(Operation *op);

const Liveness *getLiveness(Value val);
// This only remarks that Liveness results are stale.
void invalidate() { valid = false; }
/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
/// The function is called by analysis_impl::isInvalidated.
bool isInvalidated(AnalysisManager::PreservedAnalyses &) const {
return !valid;
}

private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
bool valid{true};
};

} // end namespace mlir::dataflow
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,6 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
}

const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
assert(valid && "getLiveness called after invalidate");
return solver.lookupState<Liveness>(val);
}
161 changes: 156 additions & 5 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <memory>
Expand Down Expand Up @@ -869,10 +871,159 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
};
} // namespace

/// If the target of CallOp is a public function and at least one argument is
/// NonLive, privatize the function. Our strategy here is separation interface
/// and implementation. eg.
///
/// public void foo(int unused){...}
/// =>
/// public void foo(int unused) { // old function, interface
/// return __foo_privatized(unused);
/// }
///
/// private void __foo_privatized(int unused) { // the new private function, or
/// implementation.
/// ... // the function body of the
/// original function.
/// }
///
/// changed = true if any IR changes were made.
///
/// Cloning has to be Interface-based because downstream projects may use their
/// own func/call/return ops.
static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
RunLivenessAnalysis &la,
SymbolTableCollection *symbolTable,
bool &changed) {
Operation *callableOp = callOp.resolveCallableInTable(symbolTable);
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
if (!funcOp || !funcOp.isPublic())
return LogicalResult::success();

LDBG() << "Processing callOp " << callOp << " target is a public function: "
<< funcOp.getOperation()->getName();

// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(callOp.getArgOperands());
BitVector nonLiveArgs = markLives(arguments, DenseSet<Value>(), la);
nonLiveArgs = nonLiveArgs.flip();

if (nonLiveArgs.count() > 0) {
OpBuilder rewriter(moduleOp.getContext());

// Clone function and create private version
FunctionOpInterface clonedFunc =
cast<FunctionOpInterface>(funcOp->cloneWithoutRegions());

// Set visibility = 'private' and a new name for the cloned function
SymbolTable::setSymbolVisibility(clonedFunc,
SymbolTable::Visibility::Private);
std::string newName = "__" + funcOp.getName().str() + "_privatized";
clonedFunc.setName(newName);

// Insert the cloned function into the module
rewriter.setInsertionPointAfter(funcOp);
rewriter.insert(clonedFunc);

// Replace ALL callsites of the original function to call the cloned
// function directly
LogicalResult result = SymbolTable::replaceAllSymbolUses(
funcOp, clonedFunc.getNameAttr(), moduleOp);

if (result.failed()) {
callOp.emitError(
"Failed to replace all symbol uses when privatizing function ")
<< funcOp.getName();
return result;
}
LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
<< newName;

Region &clonedFuncBody = clonedFunc.getFunctionBody();
// Move the body from funcOp to clonedFunc
clonedFuncBody.takeBody(funcOp.getFunctionBody());

// Create a new entry block for the wrapper function in funcOp
Block *wrapperBlock = rewriter.createBlock(&funcOp.getFunctionBody());

// Add block arguments that match the function signature
for (Type argType : funcOp.getArgumentTypes()) {
wrapperBlock->addArgument(argType, funcOp.getLoc());
}

// Set insertion point to the new block
rewriter.setInsertionPointToStart(wrapperBlock);

// Clone the original call operation and update its callee
auto clonedCallOp = cast<CallOpInterface>(callOp->clone());
// Update the callee symbol reference to point to the new private function
auto symbolRef =
SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
clonedCallOp.setCalleeFromCallable(symbolRef);
// Set the call arguments to use the wrapper block's arguments
clonedCallOp->setOperands(wrapperBlock->getArguments());
rewriter.insert(clonedCallOp);

// Create return operation of the same type as the original function's
// returnOp.
Operation *returnOp = nullptr;
for (Block &block : clonedFuncBody) {
if (block.getNumSuccessors() > 0)
continue;

Operation *terminator = block.getTerminator();
if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
returnOp = terminator;
break; // Use first return as template
}
}

if (returnOp) {
Operation *newReturnOp = returnOp->clone();
newReturnOp->setOperands(clonedCallOp->getResults());
newReturnOp->setLoc(returnOp->getLoc());
rewriter.insert(newReturnOp);
}
changed = true; // Changes were made
}

return LogicalResult::success();
}

void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
AnalysisManager am = getAnalysisManager();
RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();

// In a module, only privatize public functions if liveness analysis is
// inter-procedural.
if (la->getSolverConfig().isInterprocedural() && isa<ModuleOp>(module)) {
bool changed = false;
SymbolTableCollection symbolTable;
WalkResult walkResult =
module->walk([&](CallOpInterface callOp) -> WalkResult {
return processCallOp(callOp, cast<ModuleOp>(module), *la,
&symbolTable, changed);
});
if (walkResult.wasInterrupted()) {
signalPassFailure();
return;
}

if (changed) {
LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
auto &pa = getPassState().preservedAnalyses;
bool preserved = pa.isPreserved<RunLivenessAnalysis>();
la->invalidate();
am.invalidate(pa);
la = &am.getAnalysis<RunLivenessAnalysis>();
// If RunLivenessAnalysis was previously preserved, preserved the updated
// results.
if (preserved)
pa.preserve<RunLivenessAnalysis>();
}
}

// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
Expand All @@ -883,19 +1034,19 @@ void RemoveDeadValues::runOnOperation() {

module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
processFuncOp(funcOp, module, *la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
processBranchOp(branchOp, *la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
} else {
processSimpleOp(op, la, deadVals, finalCleanupList);
processSimpleOp(op, *la, deadVals, finalCleanupList);
}
});

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,54 @@ module @return_void_with_unused_argument {
}
}

// check that public functions with non-live arguments correctly.
module @public_function_with_nonlive_arguments {
// the function signature is immutable because it is public.
func.func public @public_fn_with_unused_argument(%unused: i32) -> () {
return
}
// CHECK-LABEL: func.func @main
// CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
func.func @main() -> () {
%zero = arith.constant 0 : i32
call @public_fn_with_unused_argument(%zero) : (i32) -> ()
return
}

// CHECK-LABEL: func.func @main2
// CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
func.func @main2(%arg0: i1) {
%0 = scf.if %arg0 -> (i32) {
%c1_i32 = arith.constant 1 : i32
scf.yield %c1_i32 : i32
} else {
%c0_i32 = arith.constant 0 : i32
scf.yield %c0_i32 : i32
}

call @public_fn_with_unused_argument(%0) : (i32) -> ()
return
}

func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) {
%one = arith.constant 1 : i32
%two = arith.constant 2 : i32
%three = arith.constant 4 : i32

return %one, %two, %three: i32, i32, i32
}

// CHECK-LABEL: func.func @main3
// CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32)
func.func @main3(%arg: i32) -> () {
%one = arith.constant 1 : i32
%scalar = arith.addi %arg, %one: i32

call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32)
return
}
}

// -----

// CHECK-LABEL: module @dynamically_unreachable
Expand Down