Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);

/// Visits a call operation. Given the operand lattices, sets the result
/// lattices. Performs interprocedural data flow as follows: if the call
/// operation targets an external function, or if the solver is not
/// interprocedural, attempts to infer the results from the call arguments
/// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
/// result lattices from the return sites if all return sites are known;
/// otherwise, conservatively marks the result lattices as having reached
/// their pessimistic fixpoints.
/// This method can be overridden to, for example, be less conservative and
/// propagate the information even if some return sites are unknown.
virtual LogicalResult
visitCallOperation(CallOpInterface call,
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices);

/// Visits a callable operation. Computes the argument lattices from call
/// sites if all call sites are known; otherwise, conservatively marks them
/// as having reached their pessimistic fixpoints.
/// This method can be overridden to, for example, be less conservative and
/// propagate the information even if some call sites are unknown.
virtual void
visitCallableOperation(CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> argLattices);

private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
Expand Down Expand Up @@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);

/// Visits a callable operation. If all the call sites are known computes the
/// operand lattices of `op` from the result lattices of all the call sites;
/// otherwise, conservatively marks them as having reached their pessimistic
/// fixpoints.
/// This method can be overridden to, for example, be less conservative and
/// propagate the information even if some call sites are unknown.
virtual LogicalResult
visitCallableOperation(Operation *op, CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> operandLattices);

private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
Expand Down
146 changes: 80 additions & 66 deletions mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
operandLattices.push_back(operandLattice);
}

if (auto call = dyn_cast<CallOpInterface>(op)) {
// If the call operation is to an external function, attempt to infer the
// results from the call arguments.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
visitExternalCallImpl(call, operandLattices, resultLattices);
return success();
}

// Otherwise, the results of a call operation are determined by the
// callgraph.
const auto *predecessors = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(call));
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown()) {
setAllToEntryStates(resultLattices);
return success();
}
for (Operation *predecessor : predecessors->getKnownPredecessors())
for (auto &&[operand, resLattice] :
llvm::zip(predecessor->getOperands(), resultLattices))
join(resLattice,
*getLatticeElementFor(getProgramPointAfter(op), operand));
return success();
}
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, operandLattices, resultLattices);

// Invoke the operation transfer function.
return visitOperationImpl(op, operandLattices, resultLattices);
Expand Down Expand Up @@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (block->isEntryBlock()) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointBefore(block), getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural()) {
return setAllToEntryStates(argLattices);
}
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
join(std::get<1>(it),
*getLatticeElementFor(getProgramPointBefore(block),
std::get<0>(it)));
}
return;
}
if (callable && callable.getCallableRegion() == block->getParent())
return visitCallableOperation(callable, argLattices);

// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
Expand Down Expand Up @@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
}

LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call,
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) {
// If the call operation is to an external function, attempt to infer the
// results from the call arguments.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
visitExternalCallImpl(call, operandLattices, resultLattices);
return success();
}

// Otherwise, the results of a call operation are determined by the
// callgraph.
const auto *predecessors = getOrCreateFor<PredecessorState>(
getProgramPointAfter(call), getProgramPointAfter(call));
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown()) {
setAllToEntryStates(resultLattices);
return success();
}
for (Operation *predecessor : predecessors->getKnownPredecessors())
for (auto &&[operand, resLattice] :
llvm::zip(predecessor->getOperands(), resultLattices))
join(resLattice,
*getLatticeElementFor(getProgramPointAfter(call), operand));
return success();
}

void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> argLattices) {
Block *entryBlock = &callable.getCallableRegion()->front();
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural()) {
return setAllToEntryStates(argLattices);
}
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
join(std::get<1>(it),
*getLatticeElementFor(getProgramPointBefore(entryBlock),
std::get<0>(it)));
}
}

void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
Expand Down Expand Up @@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
if (op->hasTrait<OpTrait::ReturnLike>()) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
if (callsites->allPredecessorsKnown()) {
for (Operation *call : callsites->getKnownPredecessors()) {
SmallVector<const AbstractSparseLattice *> callResultLattices =
getLatticeElementsFor(getProgramPointAfter(op),
call->getResults());
for (auto [op, result] :
llvm::zip(operandLattices, callResultLattices))
meet(op, *result);
}
} else {
// If we don't know all the callers, we can't know where the
// returned values go. Note that, in particular, this will trigger
// for the return ops of any public functions.
setAllToExitStates(operandLattices);
}
return success();
}
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
return visitCallableOperation(op, callable, operandLattices);
}

return visitOperationImpl(op, operandLattices, resultLattices);
}

LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
Operation *op, CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> operandLattices) {
const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
if (callsites->allPredecessorsKnown()) {
for (Operation *call : callsites->getKnownPredecessors()) {
SmallVector<const AbstractSparseLattice *> callResultLattices =
getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
meet(op, *result);
}
} else {
// If we don't know all the callers, we can't know where the
// returned values go. Note that, in particular, this will trigger
// for the return ops of any public functions.
setAllToExitStates(operandLattices);
}
return success();
}

void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
Expand Down
Loading