diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 1b2c679176107..3f8874d02afad 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { /// Join the lattice element and propagate and update if it changed. void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); + /// Visits a call operation. Given the operand lattices, sets the result + /// lattices. Performs interprocedural data flow as follows: if the call + /// operation targets an external function, or if the solver is not + /// interprocedural, attempts to infer the results from the call arguments + /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the + /// result lattices from the return sites if all return sites are known; + /// otherwise, conservatively marks the result lattices as having reached + /// their pessimistic fixpoints. + /// This method can be overridden to, for example, be less conservative and + /// propagate the information even if some return sites are unknown. + virtual LogicalResult + visitCallOperation(CallOpInterface call, + ArrayRef operandLattices, + ArrayRef resultLattices); + + /// Visits a callable operation. Computes the argument lattices from call + /// sites if all call sites are known; otherwise, conservatively marks them + /// as having reached their pessimistic fixpoints. + /// This method can be overridden to, for example, be less conservative and + /// propagate the information even if some call sites are unknown. + virtual void + visitCallableOperation(CallableOpInterface callable, + ArrayRef argLattices); + private: /// Recursively initialize the analysis on nested operations and blocks. LogicalResult initializeRecursively(Operation *op); @@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// Join the lattice element and propagate and update if it changed. void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); + /// Visits a callable operation. If all the call sites are known computes the + /// operand lattices of `op` from the result lattices of all the call sites; + /// otherwise, conservatively marks them as having reached their pessimistic + /// fixpoints. + /// This method can be overridden to, for example, be less conservative and + /// propagate the information even if some call sites are unknown. + virtual LogicalResult + visitCallableOperation(Operation *op, CallableOpInterface callable, + ArrayRef operandLattices); + private: /// Recursively initialize the analysis on nested operations and blocks. LogicalResult initializeRecursively(Operation *op); diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 0b39d14042493..016e59dcb744e 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { operandLattices.push_back(operandLattice); } - if (auto call = dyn_cast(op)) { - // If the call operation is to an external function, attempt to infer the - // results from the call arguments. - auto callable = - dyn_cast_if_present(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { - visitExternalCallImpl(call, operandLattices, resultLattices); - return success(); - } - - // Otherwise, the results of a call operation are determined by the - // callgraph. - const auto *predecessors = getOrCreateFor( - getProgramPointAfter(op), getProgramPointAfter(call)); - // If not all return sites are known, then conservatively assume we can't - // reason about the data-flow. - if (!predecessors->allPredecessorsKnown()) { - setAllToEntryStates(resultLattices); - return success(); - } - for (Operation *predecessor : predecessors->getKnownPredecessors()) - for (auto &&[operand, resLattice] : - llvm::zip(predecessor->getOperands(), resultLattices)) - join(resLattice, - *getLatticeElementFor(getProgramPointAfter(op), operand)); - return success(); - } + if (auto call = dyn_cast(op)) + return visitCallOperation(call, operandLattices, resultLattices); // Invoke the operation transfer function. return visitOperationImpl(op, operandLattices, resultLattices); @@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { if (block->isEntryBlock()) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast(block->getParentOp()); - if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor( - getProgramPointBefore(block), getProgramPointAfter(callable)); - // If not all callsites are known, conservatively mark all lattices as - // having reached their pessimistic fixpoints. - if (!callsites->allPredecessorsKnown() || - !getSolverConfig().isInterprocedural()) { - return setAllToEntryStates(argLattices); - } - for (Operation *callsite : callsites->getKnownPredecessors()) { - auto call = cast(callsite); - for (auto it : llvm::zip(call.getArgOperands(), argLattices)) - join(std::get<1>(it), - *getLatticeElementFor(getProgramPointBefore(block), - std::get<0>(it))); - } - return; - } + if (callable && callable.getCallableRegion() == block->getParent()) + return visitCallableOperation(callable, argLattices); // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast(block->getParentOp())) { @@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } } +LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation( + CallOpInterface call, + ArrayRef operandLattices, + ArrayRef resultLattices) { + // If the call operation is to an external function, attempt to infer the + // results from the call arguments. + auto callable = + dyn_cast_if_present(call.resolveCallable()); + if (!getSolverConfig().isInterprocedural() || + (callable && !callable.getCallableRegion())) { + visitExternalCallImpl(call, operandLattices, resultLattices); + return success(); + } + + // Otherwise, the results of a call operation are determined by the + // callgraph. + const auto *predecessors = getOrCreateFor( + getProgramPointAfter(call), getProgramPointAfter(call)); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) { + setAllToEntryStates(resultLattices); + return success(); + } + for (Operation *predecessor : predecessors->getKnownPredecessors()) + for (auto &&[operand, resLattice] : + llvm::zip(predecessor->getOperands(), resultLattices)) + join(resLattice, + *getLatticeElementFor(getProgramPointAfter(call), operand)); + return success(); +} + +void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation( + CallableOpInterface callable, + ArrayRef argLattices) { + Block *entryBlock = &callable.getCallableRegion()->front(); + const auto *callsites = getOrCreateFor( + getProgramPointBefore(entryBlock), getProgramPointAfter(callable)); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (!callsites->allPredecessorsKnown() || + !getSolverConfig().isInterprocedural()) { + return setAllToEntryStates(argLattices); + } + for (Operation *callsite : callsites->getKnownPredecessors()) { + auto call = cast(callsite); + for (auto it : llvm::zip(call.getArgOperands(), argLattices)) + join(std::get<1>(it), + *getLatticeElementFor(getProgramPointBefore(entryBlock), + std::get<0>(it))); + } +} + void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef lattices) { @@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { if (op->hasTrait()) { // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. - if (auto callable = dyn_cast(op->getParentOp())) { - const PredecessorState *callsites = getOrCreateFor( - getProgramPointAfter(op), getProgramPointAfter(callable)); - if (callsites->allPredecessorsKnown()) { - for (Operation *call : callsites->getKnownPredecessors()) { - SmallVector callResultLattices = - getLatticeElementsFor(getProgramPointAfter(op), - call->getResults()); - for (auto [op, result] : - llvm::zip(operandLattices, callResultLattices)) - meet(op, *result); - } - } else { - // If we don't know all the callers, we can't know where the - // returned values go. Note that, in particular, this will trigger - // for the return ops of any public functions. - setAllToExitStates(operandLattices); - } - return success(); - } + if (auto callable = dyn_cast(op->getParentOp())) + return visitCallableOperation(op, callable, operandLattices); } return visitOperationImpl(op, operandLattices, resultLattices); } +LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation( + Operation *op, CallableOpInterface callable, + ArrayRef operandLattices) { + const PredecessorState *callsites = getOrCreateFor( + getProgramPointAfter(op), getProgramPointAfter(callable)); + if (callsites->allPredecessorsKnown()) { + for (Operation *call : callsites->getKnownPredecessors()) { + SmallVector callResultLattices = + getLatticeElementsFor(getProgramPointAfter(op), call->getResults()); + for (auto [op, result] : llvm::zip(operandLattices, callResultLattices)) + meet(op, *result); + } + } else { + // If we don't know all the callers, we can't know where the + // returned values go. Note that, in particular, this will trigger + // for the return ops of any public functions. + setAllToExitStates(operandLattices); + } + return success(); +} + void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( RegionBranchOpInterface branch, ArrayRef operandLattices) {