Skip to content
Closed
3 changes: 3 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ struct RunLivenessAnalysis {

const Liveness *getLiveness(Value val);

/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }

private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/IR/Visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ struct ForwardIterator {
}
};

/// This iterator enumerates the elements in "backward" order.
struct BackwardIterator {
template <typename T>
static auto makeIterable(T &range) {
if constexpr (std::is_same<T, Operation>()) {
/// Make operations iterable: return the list of regions.
return llvm::reverse(range.getRegions());
} else {
/// Regions and block are already iterable.
return llvm::reverse(range);
}
}
};

/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
Expand Down
114 changes: 95 additions & 19 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
Expand Down Expand Up @@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
RunLivenessAnalysis &la) {
const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
if (liveSet.contains(value)) {
LDBG() << "Value " << value << " is marked live by CallOp";
return true;
}

if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
Expand All @@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);

Expand All @@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}

if (liveSet.contains(value)) {
continue;
}
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
Expand Down Expand Up @@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
for (Value val : op->getResults()) {
if (liveSet.contains(val)) {
LDBG() << "Simple op is used by a public function, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
liveSet.insert_range(op->getOperands());
return;
}
}

if (!isMemoryEffectFree(op) ||
hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
Expand Down Expand Up @@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
Expand All @@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,

// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();

// Do (1).
Expand Down Expand Up @@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
BitVector liveCallRets =
markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}

Expand All @@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}

// Create a cheaper value with the same type of oldVal in front of CallOp.
static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
OpBuilder builder(callOp.getOperation());
Type type = oldVal.getType();

// Create zero constant for any supported type
if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
}
return {};
}

static void processCallOp(CallOpInterface callOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet) {
if (!la.getSolverConfig().isInterprocedural())
return;

Operation *callableOp = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
if (!funcOp || !funcOp.isPublic()) {
return;
}

LDBG() << "processCallOp to a public function: " << funcOp.getName();
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();

if (nonLiveArgs.count() > 0) {
LDBG() << funcOp.getName() << " contains NonLive arguments";
// The number of operands in the call op may not match the number of
// arguments in the func op.
SmallVector<OpOperand *> callOpOperands =
operandsToOpOperands(callOp.getArgOperands());

for (int index : nonLiveArgs.set_bits()) {
OpOperand *operand = callOpOperands[index];
Value oldVal = operand->get();
if (Value dummy = createDummyArgument(callOp, oldVal)) {
callOp->setOperand(operand->getOperandNumber(), dummy);
nonLiveSet.insert(oldVal);
} else {
liveSet.insert(oldVal);
}
}
}
}

/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
Expand Down Expand Up @@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
liveResults =
markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};

// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
Expand All @@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[&region] = regionLiveArgs;
}
};
Expand Down Expand Up @@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
!hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
Expand Down Expand Up @@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,

static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();

Expand All @@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,

// Do (2)
BitVector successorNonLive =
markLives(operandValues, nonLiveSet, la).flip();
markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);

Expand Down Expand Up @@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
// mark outgoing arguments to a public function LIVE. We also propagate
// liveness backward.
DenseSet<Value> liveVals;

// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;

module->walk([&](Operation *op) {
module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
processSimpleOp(op, la, deadVals, finalCleanupList);
processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}

// the function signature is immutable because it is public.
func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
return
}

// CHECK-LABEL: func.func @main2
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
// CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
// CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
func.func @main2() -> () {
%one = arith.constant 1 : i32
%scalar = arith.addi %one, %one: i32
%mem = memref.alloc() : memref<4xf32>

call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
return
}
}

// -----
Expand Down