Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
<< " check on unreachable code now:";
// The framework doesn't visit operations in dead blocks, so we need to
// explicitly mark them as dead.
op->walk([&](Operation *op) {
if (op->getNumResults() == 0)
return;
for (auto result : llvm::enumerate(op->getResults())) {
if (getLiveness(result.value()))
continue;
LDBG() << "Result: " << result.index() << " of "
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
<< " has no liveness info (unreachable), mark dead";
solver.getOrCreateState<Liveness>(result.value());
}
for (auto &region : op->getRegions()) {
for (auto &block : region) {
for (auto blockArg : llvm::enumerate(block.getArguments())) {
if (getLiveness(blockArg.value()))
continue;
LDBG() << "Block argument: " << blockArg.index() << " of "
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
<< " has no liveness info, mark dead";
solver.getOrCreateState<Liveness>(blockArg.value());
}
}
}
});
}

const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
Expand Down
46 changes: 41 additions & 5 deletions mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;

#define DEBUG_TYPE "dataflow"

//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {

LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
LDBG() << "Initializing recursively for operation: " << op->getName();

// Initialize the analysis by visiting every owner of an SSA value (all
// operations and blocks).
if (failed(visitOperation(op)))
if (failed(visitOperation(op))) {
LDBG() << "Failed to visit operation: " << op->getName();
return failure();
}

for (Region &region : op->getRegions()) {
LDBG() << "Processing region with " << region.getBlocks().size()
<< " blocks";
for (Block &block : region) {
LDBG() << "Processing block with " << block.getNumArguments()
<< " arguments";
getOrCreate<Executable>(getProgramPointBefore(&block))
->blockContentSubscribe(this);
visitBlock(&block);
for (Operation &op : block)
if (failed(initializeRecursively(&op)))
for (Operation &op : block) {
LDBG() << "Recursively initializing nested operation: " << op.getName();
if (failed(initializeRecursively(&op))) {
LDBG() << "Failed to initialize nested operation: " << op.getName();
return failure();
}
}
}
}

LDBG() << "Successfully completed recursive initialization for operation: "
<< op->getName();
return success();
}

Expand Down Expand Up @@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
LDBG() << "Visiting operation: " << op->getName() << " with "
<< op->getNumOperands() << " operands and " << op->getNumResults()
<< " results";

// If we're in a dead block, bail out.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
LDBG() << "Operation is in dead block, bailing out";
return success();
}

LDBG() << "Creating lattice elements for " << op->getNumOperands()
<< " operands and " << op->getNumResults() << " results";
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
Expand All @@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Block arguments of region branch operations flow back into the operands
// of the parent op
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
LDBG() << "Processing RegionBranchOpInterface operation";
visitRegionSuccessors(branch, operandLattices);
return success();
}

if (auto branch = dyn_cast<BranchOpInterface>(op)) {
LDBG() << "Processing BranchOpInterface operation with "
<< op->getNumSuccessors() << " successors";

// Block arguments of successor blocks flow back into our operands.

// We remember all operands not forwarded to any block in a BitVector.
Expand Down Expand Up @@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
LDBG() << "Processing CallOpInterface operation";
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
Expand Down Expand Up @@ -513,19 +544,24 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// of this op itself and the operands of the terminators of the regions of
// this op.
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
visitRegionSuccessorsFromTerminator(terminator, branch);
return success();
}
}

if (op->hasTrait<OpTrait::ReturnLike>()) {
LDBG() << "Processing ReturnLike operation";
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
LDBG() << "Callable parent found, visiting callable operation";
return visitCallableOperation(op, callable, operandLattices);
}
}

LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
return visitOperationImpl(op, operandLattices, resultLattices);
}

Expand Down
46 changes: 40 additions & 6 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
LDBG()
<< "Simple op is not memory effect free or has live results, skipping: "
<< *op;
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}

LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
<< *op;
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
Expand Down Expand Up @@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
LDBG() << "Starting cleanup of dead values...";

// 1. Operations
LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
LDBG() << "Erasing operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}

// 2. Values
LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}

// 3. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
LDBG() << " Erasing " << f.nonLiveRets.count()
<< " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
Expand All @@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}

// 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
if (o.op->getNumOperands() > 0)
if (o.op->getNumOperands() > 0) {
LDBG() << "Erasing " << o.nonLive.count()
<< " non-live operands from operation: "
<< OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
}
}

// 5. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG() << "Erasing " << r.nonLive.count()
<< " non-live results from operation: "
<< OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}

// 6. Blocks
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
LDBG() << "Erasing " << b.nonLiveArgs.count()
<< " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}

// 7. Successor Operands
LDBG() << "Cleaning up " << list.successorOperands.size()
<< " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
LDBG() << "Erasing " << op.nonLiveOperands.count()
<< " non-live successor operands from successor "
<< op.successorIndex << " of branch: "
<< OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
LDBG() << " Erasing successor operand " << i << ": "
<< successorOperands[i];
successorOperands.erase(i);
}
}

LDBG() << "Finished cleanup of dead values";
}

struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
%0:2 = func.call @private_1() : () -> (i32, i32)
return %0#0 : i32
}

// -----

// Test that we correctly set a liveness value for operations in dead block.
// These won't be visited by the dataflow framework so the analysis need to
// explicitly manage them.
// CHECK-LABEL: test_tag: dead_block_cmpi:
// CHECK-NEXT: operand #0: not live
// CHECK-NEXT: operand #1: not live
// CHECK-NEXT: result #0: not live
func.func @dead_block() {
%false = arith.constant false
%zero = arith.constant 0 : i64
cf.cond_br %false, ^bb1, ^bb4
^bb1:
%3 = arith.cmpi eq, %zero, %zero {tag = "dead_block_cmpi"} : i64
cf.br ^bb1
^bb4:
return
}
21 changes: 21 additions & 0 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,24 @@ module @return_void_with_unused_argument {
}
}

// -----

// CHECK-LABEL: module @dynamically_unreachable
module @dynamically_unreachable {
func.func @dynamically_unreachable() {
// This value is used by an operation in a dynamically unreachable block.
%zero = arith.constant 0 : i64

// Dataflow analysis knows from the constant condition that
// ^bb1 is unreachable
%false = arith.constant false
cf.cond_br %false, ^bb1, ^bb4
^bb1:
// This unreachable operation should be removed.
// CHECK-NOT: arith.cmpi
%3 = arith.cmpi eq, %zero, %zero : i64
cf.br ^bb1
^bb4:
return
}
}
1 change: 0 additions & 1 deletion mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass

void runOnOperation() override {
auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();

Operation *op = getOperation();

raw_ostream &os = llvm::outs();
Expand Down