@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
128128 operandLattices.push_back (operandLattice);
129129 }
130130
131- if (auto call = dyn_cast<CallOpInterface>(op)) {
132- // If the call operation is to an external function, attempt to infer the
133- // results from the call arguments.
134- auto callable =
135- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable ());
136- if (!getSolverConfig ().isInterprocedural () ||
137- (callable && !callable.getCallableRegion ())) {
138- visitExternalCallImpl (call, operandLattices, resultLattices);
139- return success ();
140- }
141-
142- // Otherwise, the results of a call operation are determined by the
143- // callgraph.
144- const auto *predecessors = getOrCreateFor<PredecessorState>(
145- getProgramPointAfter (op), getProgramPointAfter (call));
146- // If not all return sites are known, then conservatively assume we can't
147- // reason about the data-flow.
148- if (!predecessors->allPredecessorsKnown ()) {
149- setAllToEntryStates (resultLattices);
150- return success ();
151- }
152- for (Operation *predecessor : predecessors->getKnownPredecessors ())
153- for (auto &&[operand, resLattice] :
154- llvm::zip (predecessor->getOperands (), resultLattices))
155- join (resLattice,
156- *getLatticeElementFor (getProgramPointAfter (op), operand));
157- return success ();
158- }
131+ if (auto call = dyn_cast<CallOpInterface>(op))
132+ return visitCallOperation (call, operandLattices, resultLattices);
159133
160134 // Invoke the operation transfer function.
161135 return visitOperationImpl (op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
183157 if (block->isEntryBlock ()) {
184158 // Check if this block is the entry block of a callable region.
185159 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp ());
186- if (callable && callable.getCallableRegion () == block->getParent ()) {
187- const auto *callsites = getOrCreateFor<PredecessorState>(
188- getProgramPointBefore (block), getProgramPointAfter (callable));
189- // If not all callsites are known, conservatively mark all lattices as
190- // having reached their pessimistic fixpoints.
191- if (!callsites->allPredecessorsKnown () ||
192- !getSolverConfig ().isInterprocedural ()) {
193- return setAllToEntryStates (argLattices);
194- }
195- for (Operation *callsite : callsites->getKnownPredecessors ()) {
196- auto call = cast<CallOpInterface>(callsite);
197- for (auto it : llvm::zip (call.getArgOperands (), argLattices))
198- join (std::get<1 >(it),
199- *getLatticeElementFor (getProgramPointBefore (block),
200- std::get<0 >(it)));
201- }
202- return ;
203- }
160+ if (callable && callable.getCallableRegion () == block->getParent ())
161+ return visitCallableOperation (callable, argLattices);
204162
205163 // Check if the lattices can be determined from region control flow.
206164 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp ())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
248206 }
249207}
250208
209+ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation (
210+ CallOpInterface call,
211+ ArrayRef<const AbstractSparseLattice *> operandLattices,
212+ ArrayRef<AbstractSparseLattice *> resultLattices) {
213+ // If the call operation is to an external function, attempt to infer the
214+ // results from the call arguments.
215+ auto callable =
216+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable ());
217+ if (!getSolverConfig ().isInterprocedural () ||
218+ (callable && !callable.getCallableRegion ())) {
219+ visitExternalCallImpl (call, operandLattices, resultLattices);
220+ return success ();
221+ }
222+
223+ // Otherwise, the results of a call operation are determined by the
224+ // callgraph.
225+ const auto *predecessors = getOrCreateFor<PredecessorState>(
226+ getProgramPointAfter (call), getProgramPointAfter (call));
227+ // If not all return sites are known, then conservatively assume we can't
228+ // reason about the data-flow.
229+ if (!predecessors->allPredecessorsKnown ()) {
230+ setAllToEntryStates (resultLattices);
231+ return success ();
232+ }
233+ for (Operation *predecessor : predecessors->getKnownPredecessors ())
234+ for (auto &&[operand, resLattice] :
235+ llvm::zip (predecessor->getOperands (), resultLattices))
236+ join (resLattice,
237+ *getLatticeElementFor (getProgramPointAfter (call), operand));
238+ return success ();
239+ }
240+
241+ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation (
242+ CallableOpInterface callable,
243+ ArrayRef<AbstractSparseLattice *> argLattices) {
244+ Block *entryBlock = &callable.getCallableRegion ()->front ();
245+ const auto *callsites = getOrCreateFor<PredecessorState>(
246+ getProgramPointBefore (entryBlock), getProgramPointAfter (callable));
247+ // If not all callsites are known, conservatively mark all lattices as
248+ // having reached their pessimistic fixpoints.
249+ if (!callsites->allPredecessorsKnown () ||
250+ !getSolverConfig ().isInterprocedural ()) {
251+ return setAllToEntryStates (argLattices);
252+ }
253+ for (Operation *callsite : callsites->getKnownPredecessors ()) {
254+ auto call = cast<CallOpInterface>(callsite);
255+ for (auto it : llvm::zip (call.getArgOperands (), argLattices))
256+ join (std::get<1 >(it),
257+ *getLatticeElementFor (getProgramPointBefore (entryBlock),
258+ std::get<0 >(it)));
259+ }
260+ }
261+
251262void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors (
252263 ProgramPoint *point, RegionBranchOpInterface branch,
253264 RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
512523 if (op->hasTrait <OpTrait::ReturnLike>()) {
513524 // Going backwards, the operands of the return are derived from the
514525 // results of all CallOps calling this CallableOp.
515- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp ())) {
516- const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
517- getProgramPointAfter (op), getProgramPointAfter (callable));
518- if (callsites->allPredecessorsKnown ()) {
519- for (Operation *call : callsites->getKnownPredecessors ()) {
520- SmallVector<const AbstractSparseLattice *> callResultLattices =
521- getLatticeElementsFor (getProgramPointAfter (op),
522- call->getResults ());
523- for (auto [op, result] :
524- llvm::zip (operandLattices, callResultLattices))
525- meet (op, *result);
526- }
527- } else {
528- // If we don't know all the callers, we can't know where the
529- // returned values go. Note that, in particular, this will trigger
530- // for the return ops of any public functions.
531- setAllToExitStates (operandLattices);
532- }
533- return success ();
534- }
526+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp ()))
527+ return visitCallableOperation (op, callable, operandLattices);
535528 }
536529
537530 return visitOperationImpl (op, operandLattices, resultLattices);
538531}
539532
533+ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation (
534+ Operation *op, CallableOpInterface callable,
535+ ArrayRef<AbstractSparseLattice *> operandLattices) {
536+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
537+ getProgramPointAfter (op), getProgramPointAfter (callable));
538+ if (callsites->allPredecessorsKnown ()) {
539+ for (Operation *call : callsites->getKnownPredecessors ()) {
540+ SmallVector<const AbstractSparseLattice *> callResultLattices =
541+ getLatticeElementsFor (getProgramPointAfter (op), call->getResults ());
542+ for (auto [op, result] : llvm::zip (operandLattices, callResultLattices))
543+ meet (op, *result);
544+ }
545+ } else {
546+ // If we don't know all the callers, we can't know where the
547+ // returned values go. Note that, in particular, this will trigger
548+ // for the return ops of any public functions.
549+ setAllToExitStates (operandLattices);
550+ }
551+ return success ();
552+ }
553+
540554void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors (
541555 RegionBranchOpInterface branch,
542556 ArrayRef<AbstractSparseLattice *> operandLattices) {
0 commit comments