
Conversation

@VadimCurca
Contributor

This commit introduces `visitCallOperation` and `visitCallableOperation` extension points in the sparse data flow analysis framework. This allows, for example, making the analysis less conservative without a lot of code duplication, by propagating information even when not all of the call or return sites are known.

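For illustration, here is a minimal sketch of how a client might use the new forward-analysis hook. `MyForwardAnalysis` is a hypothetical downstream analysis derived from `SparseForwardDataFlowAnalysis`, and the other overrides it needs (`visitOperation`, `setToEntryState`, ...) are elided. Whether joining only the known return sites is sound depends entirely on the client's lattice, so this shows the mechanism, not a recommended policy.

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

using namespace mlir;
using namespace mlir::dataflow;

// Hypothetical client analysis overriding the new hook.
LogicalResult MyForwardAnalysis::visitCallOperation(
    CallOpInterface call,
    ArrayRef<const AbstractSparseLattice *> operandLattices,
    ArrayRef<AbstractSparseLattice *> resultLattices) {
  // Keep the default handling for external callees and intraprocedural runs.
  auto callable =
      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
  if (!getSolverConfig().isInterprocedural() ||
      (callable && !callable.getCallableRegion()))
    return AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
        call, operandLattices, resultLattices);

  // Unlike the default implementation, do not fall back to the pessimistic
  // fixpoint when some return sites are unknown; join whatever information
  // the known return sites provide.
  const auto *predecessors = getOrCreateFor<PredecessorState>(
      getProgramPointAfter(call), getProgramPointAfter(call));
  for (Operation *returnSite : predecessors->getKnownPredecessors())
    for (auto &&[returnedValue, resultLattice] :
         llvm::zip(returnSite->getOperands(), resultLattices))
      join(resultLattice,
           *getLatticeElementFor(getProgramPointAfter(call), returnedValue));
  return success();
}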
@llvmbot llvmbot added the mlir label Jun 3, 2025
@llvmbot
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir

Author: Vadim Curcă (VadimCurca)

Changes

This commit introduces `visitCallOperation` and `visitCallableOperation` extension points in the sparse data flow analysis framework. This allows, for example, making the analysis less conservative without a lot of code duplication, by propagating information even when not all of the call or return sites are known.


Full diff: https://github.com/llvm/llvm-project/pull/142549.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+34)
  • (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+80-66)
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..3f8874d02afad 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Join the lattice element and propagate and update if it changed.
   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
 
+  /// Visits a call operation. Given the operand lattices, sets the result
+  /// lattices. Performs interprocedural data flow as follows: if the call
+  /// operation targets an external function, or if the solver is not
+  /// interprocedural, attempts to infer the results from the call arguments
+  /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
+  /// result lattices from the return sites if all return sites are known;
+  /// otherwise, conservatively marks the result lattices as having reached
+  /// their pessimistic fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some return sites are unknown.
+  virtual LogicalResult
+  visitCallOperation(CallOpInterface call,
+                     ArrayRef<const AbstractSparseLattice *> operandLattices,
+                     ArrayRef<AbstractSparseLattice *> resultLattices);
+
+  /// Visits a callable operation. Computes the argument lattices from call
+  /// sites if all call sites are known; otherwise, conservatively marks them
+  /// as having reached their pessimistic fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some call sites are unknown.
+  virtual void
+  visitCallableOperation(CallableOpInterface callable,
+                         ArrayRef<AbstractSparseLattice *> argLattices);
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
@@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Join the lattice element and propagate and update if it changed.
   void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
 
+  /// Visits a callable operation. If all the call sites are known computes the
+  /// operand lattices of `op` from the result lattices of all the call sites;
+  /// otherwise, conservatively marks them as having reached their pessimistic
+  /// fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some call sites are unknown.
+  virtual LogicalResult
+  visitCallableOperation(Operation *op, CallableOpInterface callable,
+                         ArrayRef<AbstractSparseLattice *> operandLattices);
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..016e59dcb744e 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
     operandLattices.push_back(operandLattice);
   }
 
-  if (auto call = dyn_cast<CallOpInterface>(op)) {
-    // If the call operation is to an external function, attempt to infer the
-    // results from the call arguments.
-    auto callable =
-        dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
-    if (!getSolverConfig().isInterprocedural() ||
-        (callable && !callable.getCallableRegion())) {
-      visitExternalCallImpl(call, operandLattices, resultLattices);
-      return success();
-    }
-
-    // Otherwise, the results of a call operation are determined by the
-    // callgraph.
-    const auto *predecessors = getOrCreateFor<PredecessorState>(
-        getProgramPointAfter(op), getProgramPointAfter(call));
-    // If not all return sites are known, then conservatively assume we can't
-    // reason about the data-flow.
-    if (!predecessors->allPredecessorsKnown()) {
-      setAllToEntryStates(resultLattices);
-      return success();
-    }
-    for (Operation *predecessor : predecessors->getKnownPredecessors())
-      for (auto &&[operand, resLattice] :
-           llvm::zip(predecessor->getOperands(), resultLattices))
-        join(resLattice,
-             *getLatticeElementFor(getProgramPointAfter(op), operand));
-    return success();
-  }
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    return visitCallOperation(call, operandLattices, resultLattices);
 
   // Invoke the operation transfer function.
   return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
   if (block->isEntryBlock()) {
     // Check if this block is the entry block of a callable region.
     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
-    if (callable && callable.getCallableRegion() == block->getParent()) {
-      const auto *callsites = getOrCreateFor<PredecessorState>(
-          getProgramPointBefore(block), getProgramPointAfter(callable));
-      // If not all callsites are known, conservatively mark all lattices as
-      // having reached their pessimistic fixpoints.
-      if (!callsites->allPredecessorsKnown() ||
-          !getSolverConfig().isInterprocedural()) {
-        return setAllToEntryStates(argLattices);
-      }
-      for (Operation *callsite : callsites->getKnownPredecessors()) {
-        auto call = cast<CallOpInterface>(callsite);
-        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
-          join(std::get<1>(it),
-               *getLatticeElementFor(getProgramPointBefore(block),
-                                     std::get<0>(it)));
-      }
-      return;
-    }
+    if (callable && callable.getCallableRegion() == block->getParent())
+      return visitCallableOperation(callable, argLattices);
 
     // Check if the lattices can be determined from region control flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
   }
 }
 
+LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call,
+    ArrayRef<const AbstractSparseLattice *> operandLattices,
+    ArrayRef<AbstractSparseLattice *> resultLattices) {
+  // If the call operation is to an external function, attempt to infer the
+  // results from the call arguments.
+  auto callable =
+      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+  if (!getSolverConfig().isInterprocedural() ||
+      (callable && !callable.getCallableRegion())) {
+    visitExternalCallImpl(call, operandLattices, resultLattices);
+    return success();
+  }
+
+  // Otherwise, the results of a call operation are determined by the
+  // callgraph.
+  const auto *predecessors = getOrCreateFor<PredecessorState>(
+      getProgramPointAfter(call), getProgramPointAfter(call));
+  // If not all return sites are known, then conservatively assume we can't
+  // reason about the data-flow.
+  if (!predecessors->allPredecessorsKnown()) {
+    setAllToEntryStates(resultLattices);
+    return success();
+  }
+  for (Operation *predecessor : predecessors->getKnownPredecessors())
+    for (auto &&[operand, resLattice] :
+         llvm::zip(predecessor->getOperands(), resultLattices))
+      join(resLattice,
+           *getLatticeElementFor(getProgramPointAfter(call), operand));
+  return success();
+}
+
+void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
+    CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> argLattices) {
+  Block *entryBlock = &callable.getCallableRegion()->front();
+  const auto *callsites = getOrCreateFor<PredecessorState>(
+      getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
+  // If not all callsites are known, conservatively mark all lattices as
+  // having reached their pessimistic fixpoints.
+  if (!callsites->allPredecessorsKnown() ||
+      !getSolverConfig().isInterprocedural()) {
+    return setAllToEntryStates(argLattices);
+  }
+  for (Operation *callsite : callsites->getKnownPredecessors()) {
+    auto call = cast<CallOpInterface>(callsite);
+    for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+      join(std::get<1>(it),
+           *getLatticeElementFor(getProgramPointBefore(entryBlock),
+                                 std::get<0>(it)));
+  }
+}
+
 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
     ProgramPoint *point, RegionBranchOpInterface branch,
     RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   if (op->hasTrait<OpTrait::ReturnLike>()) {
     // Going backwards, the operands of the return are derived from the
     // results of all CallOps calling this CallableOp.
-    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
-      const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
-          getProgramPointAfter(op), getProgramPointAfter(callable));
-      if (callsites->allPredecessorsKnown()) {
-        for (Operation *call : callsites->getKnownPredecessors()) {
-          SmallVector<const AbstractSparseLattice *> callResultLattices =
-              getLatticeElementsFor(getProgramPointAfter(op),
-                                    call->getResults());
-          for (auto [op, result] :
-               llvm::zip(operandLattices, callResultLattices))
-            meet(op, *result);
-        }
-      } else {
-        // If we don't know all the callers, we can't know where the
-        // returned values go. Note that, in particular, this will trigger
-        // for the return ops of any public functions.
-        setAllToExitStates(operandLattices);
-      }
-      return success();
-    }
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+      return visitCallableOperation(op, callable, operandLattices);
   }
 
   return visitOperationImpl(op, operandLattices, resultLattices);
 }
 
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
+    Operation *op, CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> operandLattices) {
+  const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+      getProgramPointAfter(op), getProgramPointAfter(callable));
+  if (callsites->allPredecessorsKnown()) {
+    for (Operation *call : callsites->getKnownPredecessors()) {
+      SmallVector<const AbstractSparseLattice *> callResultLattices =
+          getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
+      for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
+        meet(op, *result);
+    }
+  } else {
+    // If we don't know all the callers, we can't know where the
+    // returned values go. Note that, in particular, this will trigger
+    // for the return ops of any public functions.
+    setAllToExitStates(operandLattices);
+  }
+  return success();
+}
+
 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
     RegionBranchOpInterface branch,
     ArrayRef<AbstractSparseLattice *> operandLattices) {

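As a usage note, the backward hook can be customized in the same way. The sketch below assumes a hypothetical `MyBackwardAnalysis` derived from `SparseBackwardDataFlowAnalysis` (other required overrides elided); it meets in the result lattices of the call sites that are known even when the full caller set is not (for example, at the return ops of public functions), which is only appropriate for lattices where such partial information is acceptable.

// Hypothetical client analysis overriding the new backward hook.
LogicalResult MyBackwardAnalysis::visitCallableOperation(
    Operation *op, CallableOpInterface callable,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
      getProgramPointAfter(op), getProgramPointAfter(callable));
  // Unlike the default implementation, do not drop straight to the exit
  // state when the caller set is incomplete; meet the result lattices of
  // whichever call sites are known.
  for (Operation *call : callsites->getKnownPredecessors()) {
    SmallVector<const AbstractSparseLattice *> callResultLattices =
        getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
    for (auto [operandLattice, resultLattice] :
         llvm::zip(operandLattices, callResultLattices))
      meet(operandLattice, *resultLattice);
  }
  return success();
}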
@gysit gysit requested review from Mogball, ftynse and makslevental June 3, 2025 06:50
Contributor

@makslevental makslevental left a comment


LGTM

EDIT: looking again - this is a good refactor that simplifies/explains visitOperation well.

@gysit gysit merged commit 5a531b1 into llvm:main Jun 4, 2025
13 checks passed
@VadimCurca VadimCurca deleted the vadimc/data_flow_analysis_costumization branch June 4, 2025 12:15