-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] NFC: Add data flow analysis extension points #142549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] NFC: Add data flow analysis extension points #142549
Conversation
This commit introduces `visitCallOperation` and `visitCallableOperation` extension points in the sparse data flow analysis framework. This allows, for example, to make the analysis less conservative, without a lot of code duplication, propagating information even if not all the call or return sites are known.
|
@llvm/pr-subscribers-mlir Author: Vadim Curcă (VadimCurca) ChangesThis commit introduces Full diff: https://github.com/llvm/llvm-project/pull/142549.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..3f8874d02afad 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ /// Visits a call operation. Given the operand lattices, sets the result
+ /// lattices. Performs interprocedural data flow as follows: if the call
+ /// operation targets an external function, or if the solver is not
+ /// interprocedural, attempts to infer the results from the call arguments
+ /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
+ /// result lattices from the return sites if all return sites are known;
+ /// otherwise, conservatively marks the result lattices as having reached
+ /// their pessimistic fixpoints.
+ /// This method can be overridden to, for example, be less conservative and
+ /// propagate the information even if some return sites are unknown.
+ virtual LogicalResult
+ visitCallOperation(CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices);
+
+ /// Visits a callable operation. Computes the argument lattices from call
+ /// sites if all call sites are known; otherwise, conservatively marks them
+ /// as having reached their pessimistic fixpoints.
+ /// This method can be overridden to, for example, be less conservative and
+ /// propagate the information even if some call sites are unknown.
+ virtual void
+ visitCallableOperation(CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> argLattices);
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
@@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ /// Visits a callable operation. If all the call sites are known computes the
+ /// operand lattices of `op` from the result lattices of all the call sites;
+ /// otherwise, conservatively marks them as having reached their pessimistic
+ /// fixpoints.
+ /// This method can be overridden to, for example, be less conservative and
+ /// propagate the information even if some call sites are unknown.
+ virtual LogicalResult
+ visitCallableOperation(Operation *op, CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> operandLattices);
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..016e59dcb744e 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
operandLattices.push_back(operandLattice);
}
- if (auto call = dyn_cast<CallOpInterface>(op)) {
- // If the call operation is to an external function, attempt to infer the
- // results from the call arguments.
- auto callable =
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
- if (!getSolverConfig().isInterprocedural() ||
- (callable && !callable.getCallableRegion())) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
-
- // Otherwise, the results of a call operation are determined by the
- // callgraph.
- const auto *predecessors = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(call));
- // If not all return sites are known, then conservatively assume we can't
- // reason about the data-flow.
- if (!predecessors->allPredecessorsKnown()) {
- setAllToEntryStates(resultLattices);
- return success();
- }
- for (Operation *predecessor : predecessors->getKnownPredecessors())
- for (auto &&[operand, resLattice] :
- llvm::zip(predecessor->getOperands(), resultLattices))
- join(resLattice,
- *getLatticeElementFor(getProgramPointAfter(op), operand));
- return success();
- }
+ if (auto call = dyn_cast<CallOpInterface>(op))
+ return visitCallOperation(call, operandLattices, resultLattices);
// Invoke the operation transfer function.
return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (block->isEntryBlock()) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
- if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointBefore(block), getProgramPointAfter(callable));
- // If not all callsites are known, conservatively mark all lattices as
- // having reached their pessimistic fixpoints.
- if (!callsites->allPredecessorsKnown() ||
- !getSolverConfig().isInterprocedural()) {
- return setAllToEntryStates(argLattices);
- }
- for (Operation *callsite : callsites->getKnownPredecessors()) {
- auto call = cast<CallOpInterface>(callsite);
- for (auto it : llvm::zip(call.getArgOperands(), argLattices))
- join(std::get<1>(it),
- *getLatticeElementFor(getProgramPointBefore(block),
- std::get<0>(it)));
- }
- return;
- }
+ if (callable && callable.getCallableRegion() == block->getParent())
+ return visitCallableOperation(callable, argLattices);
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
}
+LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
+ CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices) {
+ // If the call operation is to an external function, attempt to infer the
+ // results from the call arguments.
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ if (!getSolverConfig().isInterprocedural() ||
+ (callable && !callable.getCallableRegion())) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+
+ // Otherwise, the results of a call operation are determined by the
+ // callgraph.
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(call), getProgramPointAfter(call));
+ // If not all return sites are known, then conservatively assume we can't
+ // reason about the data-flow.
+ if (!predecessors->allPredecessorsKnown()) {
+ setAllToEntryStates(resultLattices);
+ return success();
+ }
+ for (Operation *predecessor : predecessors->getKnownPredecessors())
+ for (auto &&[operand, resLattice] :
+ llvm::zip(predecessor->getOperands(), resultLattices))
+ join(resLattice,
+ *getLatticeElementFor(getProgramPointAfter(call), operand));
+ return success();
+}
+
+void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
+ CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> argLattices) {
+ Block *entryBlock = &callable.getCallableRegion()->front();
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
+ // If not all callsites are known, conservatively mark all lattices as
+ // having reached their pessimistic fixpoints.
+ if (!callsites->allPredecessorsKnown() ||
+ !getSolverConfig().isInterprocedural()) {
+ return setAllToEntryStates(argLattices);
+ }
+ for (Operation *callsite : callsites->getKnownPredecessors()) {
+ auto call = cast<CallOpInterface>(callsite);
+ for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+ join(std::get<1>(it),
+ *getLatticeElementFor(getProgramPointBefore(entryBlock),
+ std::get<0>(it)));
+ }
+}
+
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
if (op->hasTrait<OpTrait::ReturnLike>()) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
- const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(callable));
- if (callsites->allPredecessorsKnown()) {
- for (Operation *call : callsites->getKnownPredecessors()) {
- SmallVector<const AbstractSparseLattice *> callResultLattices =
- getLatticeElementsFor(getProgramPointAfter(op),
- call->getResults());
- for (auto [op, result] :
- llvm::zip(operandLattices, callResultLattices))
- meet(op, *result);
- }
- } else {
- // If we don't know all the callers, we can't know where the
- // returned values go. Note that, in particular, this will trigger
- // for the return ops of any public functions.
- setAllToExitStates(operandLattices);
- }
- return success();
- }
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ return visitCallableOperation(op, callable, operandLattices);
}
return visitOperationImpl(op, operandLattices, resultLattices);
}
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
+ Operation *op, CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> operandLattices) {
+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
+ if (callsites->allPredecessorsKnown()) {
+ for (Operation *call : callsites->getKnownPredecessors()) {
+ SmallVector<const AbstractSparseLattice *> callResultLattices =
+ getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
+ for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
+ meet(op, *result);
+ }
+ } else {
+ // If we don't know all the callers, we can't know where the
+ // returned values go. Note that, in particular, this will trigger
+ // for the return ops of any public functions.
+ setAllToExitStates(operandLattices);
+ }
+ return success();
+}
+
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
EDIT: looking again - this is a good refactor that simplifies/explains visitOperation well.
This commit introduces
visitCallOperationandvisitCallableOperationextension points in the sparse data flow analysis framework. This allows, for example, to make the analysis less conservative, without a lot of code duplication, propagating information even if not all the call or return sites are known.