diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 6a1335bab8bf6..088b6cd7d698f 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -27,8 +27,9 @@ namespace dataflow { // CallControlFlowAction //===----------------------------------------------------------------------===// -/// Indicates whether the control enters or exits the callee. -enum class CallControlFlowAction { EnterCallee, ExitCallee }; +/// Indicates whether the control enters, exits, or skips over the callee (in +/// the case of external functions). +enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee }; //===----------------------------------------------------------------------===// // AbstractDenseLattice @@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis { /// Propagate the dense lattice forward along the call control flow edge, /// which can be either entering or exiting the callee. Default implementation - /// just meets the states, meaning that operations implementing - /// `CallOpInterface` don't have any effect on the lattice that isn't already - /// expressed by the interface itself. + /// for enter and exit callee actions just joins the states, meaning that + /// operations implementing `CallOpInterface` don't have any effect on the + /// lattice that isn't already expressed by the interface itself. Default + /// implementation for the external callee action additionally sets the + /// "after" lattice to the entry state. virtual void visitCallControlFlowTransfer(CallOpInterface call, CallControlFlowAction action, const AbstractDenseLattice &before, AbstractDenseLattice *after) { join(after, before); + // Note that `setToEntryState` may be a "partial fixpoint" for some + // lattices, e.g., lattices that are lists of maps of other lattices will + // only set fixpoint for "known" lattices. 
+ if (action == CallControlFlowAction::ExternalCallee) + setToEntryState(after); } /// Visit a program point within a region branch operation with predecessors @@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis { /// Visit an operation for which the data flow is described by the /// `CallOpInterface`. - void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after); + void visitCallOperation(CallOpInterface call, + const AbstractDenseLattice &before, + AbstractDenseLattice *after); }; //===----------------------------------------------------------------------===// @@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// Propagate the dense lattice backwards along the call control flow edge, /// which can be either entering or exiting the callee. Default implementation - /// just meets the states, meaning that operations implementing - /// `CallOpInterface` don't have any effect on hte lattice that isn't already - /// expressed by the interface itself. + /// for enter and exit callee actions just meets the states, meaning that + /// operations implementing `CallOpInterface` don't have any effect on the + /// lattice that isn't already expressed by the interface itself. Default + /// implementation for the external callee action additionally sets the result to + /// the exit (fixpoint) state. virtual void visitCallControlFlowTransfer(CallOpInterface call, CallControlFlowAction action, const AbstractDenseLattice &after, AbstractDenseLattice *before) { meet(before, after); + + // Note that `setToExitState` may be a "partial fixpoint" for some lattices, + // e.g., lattices that are lists of maps of other lattices will only + // set fixpoint for "known" lattices. 
+ if (action == CallControlFlowAction::ExternalCallee) + setToExitState(before); } private: @@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// otherwise, /// - meet that state with the state before the call-like op, or use the /// custom logic if overridden by concrete analyses. - void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before); + void visitCallOperation(CallOpInterface call, + const AbstractDenseLattice &after, + AbstractDenseLattice *before); /// Symbol table for call-level control flow. SymbolTableCollection &symbolTable; diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 5a9a36159b56c..b65ac8bb1dec2 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -17,6 +17,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" @@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { ArrayRef operandLattices, ArrayRef resultLattices) = 0; + /// The transfer function for calls to external functions. + virtual void visitExternalCallImpl( + CallOpInterface call, + ArrayRef argumentLattices, + ArrayRef resultLattices) = 0; + /// Given an operation with region control-flow, the lattices of the operands, /// and a region successor, compute the lattice values for block arguments /// that are not accounted for by the branching control flow (ex. the bounds @@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis virtual void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) = 0; + /// Visit a call operation to an externally defined function given the + /// lattices of its arguments. 
+ virtual void visitExternalCall(CallOpInterface call, + ArrayRef argumentLattices, + ArrayRef resultLattices) { + setAllToEntryStates(resultLattices); + } + /// Given an operation with possible region control-flow, the lattices of the /// operands, and a region successor, compute the lattice values for block /// arguments that are not accounted for by the branching control flow (ex. @@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis {reinterpret_cast(resultLattices.begin()), resultLattices.size()}); } + void visitExternalCallImpl( + CallOpInterface call, + ArrayRef argumentLattices, + ArrayRef resultLattices) override { + visitExternalCall( + call, + {reinterpret_cast(argumentLattices.begin()), + argumentLattices.size()}, + {reinterpret_cast(resultLattices.begin()), + resultLattices.size()}); + } void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, ArrayRef argLattices, @@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis { Operation *op, ArrayRef operandLattices, ArrayRef resultLattices) = 0; + /// The transfer function for calls to external functions. + virtual void visitExternalCallImpl( + CallOpInterface call, ArrayRef operandLattices, + ArrayRef resultLattices) = 0; + // Visit operands on branch instructions that are not forwarded. virtual void visitBranchOperand(OpOperand &operand) = 0; @@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis virtual void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) = 0; + /// Visit a call to an external function. This function is expected to set + /// lattice values of the call operands. By default, calls `visitCallOperand` + /// for all operands. 
+ virtual void visitExternalCall(CallOpInterface call, + ArrayRef argumentLattices, + ArrayRef resultLattices) { + (void)argumentLattices; + (void)resultLattices; + for (OpOperand &operand : call->getOpOperands()) { + visitCallOperand(operand); + } + }; + protected: /// Get the lattice element for a value. StateT *getLatticeElement(Value value) override { @@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis {reinterpret_cast(resultLattices.begin()), resultLattices.size()}); } + + void visitExternalCallImpl( + CallOpInterface call, ArrayRef operandLattices, + ArrayRef resultLattices) override { + visitExternalCall( + call, + {reinterpret_cast(operandLattices.begin()), + operandLattices.size()}, + {reinterpret_cast(resultLattices.begin()), + resultLattices.size()}); + } }; } // end namespace dataflow diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index c27615b52a12b..541cdb1e237c1 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -175,6 +175,32 @@ struct ProgramPoint /// Forward declaration of the data-flow analysis class. class DataFlowAnalysis; +//===----------------------------------------------------------------------===// +// DataFlowConfig +//===----------------------------------------------------------------------===// + +/// Configuration class for data flow solver and child analyses. Follows the +/// fluent API pattern. +class DataFlowConfig { +public: + DataFlowConfig() = default; + + /// Set whether the solver should operate interprocedurally, i.e. enter the + /// callee body when available. Interprocedural analyses may be more precise, + /// but also more expensive as more states need to be computed and the + /// fixpoint convergence takes longer. 
+ DataFlowConfig &setInterprocedural(bool enable) { + interprocedural = enable; + return *this; + } + + /// Return `true` if the solver operates interprocedurally, `false` otherwise. + bool isInterprocedural() const { return interprocedural; } + +private: + bool interprocedural = true; +}; + //===----------------------------------------------------------------------===// // DataFlowSolver //===----------------------------------------------------------------------===// @@ -195,6 +221,9 @@ class DataFlowAnalysis; /// TODO: Optimize the internal implementation of the solver. class DataFlowSolver { public: + explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig()) + : config(config) {} + /// Load an analysis into the solver. Return the analysis instance. template AnalysisT *load(Args &&...args); @@ -236,7 +265,13 @@ class DataFlowSolver { /// dependent work items to the back of the queue. void propagateIfChanged(AnalysisState *state, ChangeResult changed); + /// Get the configuration of the solver. + const DataFlowConfig &getConfig() const { return config; } + private: + /// Configuration of the dataflow solver. + DataFlowConfig config; + /// The solver's work queue. Work items can be inserted to the front of the /// queue to be processed greedily, speeding up computations that otherwise /// quickly degenerate to quadratic due to propagation of state updates. @@ -423,6 +458,9 @@ class DataFlowAnalysis { return state; } + /// Return the configuration of the solver used for this analysis. + const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); } + #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// When compiling with debugging, keep a name for the analyis. 
StringRef debugName; diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index a6c9f7d7da225..08d89d6db788c 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) { } void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( - CallOpInterface call, AbstractDenseLattice *after) { + CallOpInterface call, const AbstractDenseLattice &before, + AbstractDenseLattice *after) { + // Allow for customizing the behavior of calls to external symbols, including + // when the analysis is explicitly marked as non-interprocedural. + auto callable = + dyn_cast_if_present(call.resolveCallable()); + if (!getSolverConfig().isInterprocedural() || + (callable && !callable.getCallableRegion())) { + return visitCallControlFlowTransfer( + call, CallControlFlowAction::ExternalCallee, before, after); + } const auto *predecessors = getOrCreateFor(call.getOperation(), call); - // If not all return sites are known, then conservatively assume we can't - // reason about the data-flow. + // Otherwise, if not all return sites are known, then conservatively assume we + // can't reason about the data-flow. if (!predecessors->allPredecessorsKnown()) return setToEntryState(after); @@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { // If this is a call operation, then join its lattices across known return // sites. if (auto call = dyn_cast(op)) - return visitCallOperation(call, after); + return visitCallOperation(call, *before, after); // Invoke the operation transfer function. 
visitOperationImpl(op, *before, after); @@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { if (callable && callable.getCallableRegion() == block->getParent()) { const auto *callsites = getOrCreateFor(block, callable); // If not all callsites are known, conservatively mark all lattices as - // having reached their pessimistic fixpoints. - if (!callsites->allPredecessorsKnown()) + // having reached their pessimistic fixpoints. Do the same if + // interprocedural analysis is not enabled. + if (!callsites->allPredecessorsKnown() || + !getSolverConfig().isInterprocedural()) return setToEntryState(after); for (Operation *callsite : callsites->getKnownPredecessors()) { // Get the dense lattice before the callsite. @@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) { } void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( - CallOpInterface call, AbstractDenseLattice *before) { + CallOpInterface call, const AbstractDenseLattice &after, + AbstractDenseLattice *before) { // Find the callee. Operation *callee = call.resolveCallable(&symbolTable); auto callable = dyn_cast_or_null(callee); if (!callable) return setToExitState(before); - // No region means the callee is only declared in this module and we shouldn't - // assume anything about it. + // No region means the callee is only declared in this module. Region *region = callable.getCallableRegion(); - if (!region || region->empty()) - return setToExitState(before); + if (!region || region->empty() || !getSolverConfig().isInterprocedural()) { + return visitCallControlFlowTransfer( + call, CallControlFlowAction::ExternalCallee, after, before); + } // Call-level control flow specifies the data flow here. 
// @@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), before); if (auto call = dyn_cast(op)) - return visitCallOperation(call, before); + return visitCallOperation(call, *after, before); // Invoke the operation transfer function. visitOperationImpl(op, *after, before); @@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { const auto *callsites = getOrCreateFor(block, callable); // If not all call sites are known, conservative mark all lattices as // having reached their pessimistic fix points. - if (!callsites->allPredecessorsKnown()) + if (!callsites->allPredecessorsKnown() || + !getSolverConfig().isInterprocedural()) { return setToExitState(before); + } for (Operation *callsite : callsites->getKnownPredecessors()) { const AbstractDenseLattice *after; diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 9f544d656df92..b47bba16fd902 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -116,8 +116,27 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { resultLattices); } - // The results of a call operation are determined by the callgraph. + // Grab the lattice elements of the operands. + SmallVector operandLattices; + operandLattices.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + AbstractSparseLattice *operandLattice = getLatticeElement(operand); + operandLattice->useDefSubscribe(this); + operandLattices.push_back(operandLattice); + } + if (auto call = dyn_cast(op)) { + // If the call operation is to an external function, attempt to infer the + // results from the call arguments. 
+ auto callable = + dyn_cast_if_present(call.resolveCallable()); + if (!getSolverConfig().isInterprocedural() || + (callable && !callable.getCallableRegion())) { + return visitExternalCallImpl(call, operandLattices, resultLattices); + } + + // Otherwise, the results of a call operation are determined by the + // callgraph. const auto *predecessors = getOrCreateFor(op, call); // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. @@ -129,15 +148,6 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { return; } - // Grab the lattice elements of the operands. - SmallVector operandLattices; - operandLattices.reserve(op->getNumOperands()); - for (Value operand : op->getOperands()) { - AbstractSparseLattice *operandLattice = getLatticeElement(operand); - operandLattice->useDefSubscribe(this); - operandLattices.push_back(operandLattice); - } - // Invoke the operation transfer function. visitOperationImpl(op, operandLattices, resultLattices); } @@ -168,8 +178,10 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { const auto *callsites = getOrCreateFor(block, callable); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. - if (!callsites->allPredecessorsKnown()) + if (!callsites->allPredecessorsKnown() || + !getSolverConfig().isInterprocedural()) { return setAllToEntryStates(argLattices); + } for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast(callsite); for (auto it : llvm::zip(call.getArgOperands(), argLattices)) @@ -433,19 +445,26 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // stored in `unaccounted`. BitVector unaccounted(op->getNumOperands(), true); + // If the call invokes an external function (or a function treated as + // external due to config), defer to the corresponding extension hook. 
+ // By default, it just does `visitCallOperand` for all operands. OperandRange argOperands = call.getArgOperands(); MutableArrayRef argOpOperands = operandsToOpOperands(argOperands); Region *region = callable.getCallableRegion(); - if (region && !region->empty()) { - Block &block = region->front(); - for (auto [blockArg, argOpOperand] : - llvm::zip(block.getArguments(), argOpOperands)) { - meet(getLatticeElement(argOpOperand.get()), - *getLatticeElementFor(op, blockArg)); - unaccounted.reset(argOpOperand.getOperandNumber()); - } + if (!region || region->empty() || !getSolverConfig().isInterprocedural()) + return visitExternalCallImpl(call, operandLattices, resultLattices); + + // Otherwise, propagate information from the entry point of the function + // back to operands whenever possible. + Block &block = region->front(); + for (auto [blockArg, argOpOperand] : + llvm::zip(block.getArguments(), argOpOperands)) { + meet(getLatticeElement(argOpOperand.get()), + *getLatticeElementFor(op, blockArg)); + unaccounted.reset(argOpOperand.getOperandNumber()); } + // Handle the operands of the call op that aren't forwarded to any // arguments. 
for (int index : unaccounted.set_bits()) { diff --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir index 709d787bb306b..a5eba43ac68ab 100644 --- a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir +++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir @@ -1,8 +1,32 @@ -// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 | FileCheck %s +// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP,IP_ONLY +// RUN: mlir-opt -test-last-modified='assume-func-writes=true' \ +// RUN: --split-input-file %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP,IP_AW +// RUN: mlir-opt -test-last-modified='interprocedural=false' \ +// RUN: --split-input-file %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,LOCAL +// RUN: mlir-opt \ +// RUN: -test-last-modified='interprocedural=false assume-func-writes=true' \ +// RUN: --split-input-file %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,LC_AW + +// Check prefixes are as follows: +// 'check': common for all runs; +// 'ip': interprocedural runs; +// 'ip_aw': interprocedural runs assuming calls to external functions write to +// all arguments; +// 'ip_only': interprocedural runs not assuming calls writing; +// 'local': local (non-interprocedural) analysis not assuming calls writing; +// 'lc_aw': local analysis assuming external calls writing to all arguments. 
// CHECK-LABEL: test_tag: test_callsite -// CHECK: operand #0 -// CHECK-NEXT: - a +// IP: operand #0 +// IP-NEXT: - a +// LOCAL: operand #0 +// LOCAL-NEXT: - +// LC_AW: operand #0 +// LC_AW-NEXT: - func.func private @single_callsite_fn(%ptr: memref) -> memref { return {tag = "test_callsite"} %ptr : memref } @@ -16,8 +40,12 @@ func.func @test_callsite() { } // CHECK-LABEL: test_tag: test_return_site -// CHECK: operand #0 -// CHECK-NEXT: - b +// IP: operand #0 +// IP-NEXT: - b +// LOCAL: operand #0 +// LOCAL-NEXT: - +// LC_AW: operand #0 +// LC_AW-NEXT: - func.func private @single_return_site_fn(%ptr: memref) -> memref { %c0 = arith.constant 0 : i32 memref.store %c0, %ptr[] {tag_name = "b"} : memref @@ -25,9 +53,13 @@ func.func private @single_return_site_fn(%ptr: memref) -> memref { } // CHECK-LABEL: test_tag: test_multiple_callsites -// CHECK: operand #0 -// CHECK-NEXT: write0 -// CHECK-NEXT: write1 +// IP: operand #0 +// IP-NEXT: write0 +// IP-NEXT: write1 +// LOCAL: operand #0 +// LOCAL-NEXT: - +// LC_AW: operand #0 +// LC_AW-NEXT: - func.func @test_return_site(%ptr: memref) -> memref { %0 = func.call @single_return_site_fn(%ptr) : (memref) -> memref return {tag = "test_return_site"} %0 : memref @@ -46,9 +78,13 @@ func.func @test_multiple_callsites(%a: i32, %ptr: memref) -> memref { } // CHECK-LABEL: test_tag: test_multiple_return_sites -// CHECK: operand #0 -// CHECK-NEXT: return0 -// CHECK-NEXT: return1 +// IP: operand #0 +// IP-NEXT: return0 +// IP-NEXT: return1 +// LOCAL: operand #0 +// LOCAL-NEXT: - +// LC_AW: operand #0 +// LC_AW-NEXT: - func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref) -> memref { cf.cond_br %cond, ^a, ^b @@ -69,8 +105,12 @@ func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref) -> // ----- // CHECK-LABEL: test_tag: after_call -// CHECK: operand #0 -// CHECK-NEXT: - write0 +// IP: operand #0 +// IP-NEXT: - write0 +// LOCAL: operand #0 +// LOCAL-NEXT: - +// LC_AW: operand #0 +// LC_AW-NEXT: - 
func.call func.func private @void_return(%ptr: memref) { return } @@ -98,17 +138,29 @@ func.func private @callee(%arg0: memref) -> memref { // "pre" -> "call" -> "callee" -> "post" // CHECK-LABEL: test_tag: call_and_store_before::enter_callee: -// CHECK: operand #0 -// CHECK: - call +// IP: operand #0 +// IP: - call +// LOCAL: operand #0 +// LOCAL: - +// LC_AW: operand #0 +// LC_AW: - + // CHECK: test_tag: exit_callee: // CHECK: operand #0 // CHECK: - callee + // CHECK: test_tag: before_call: // CHECK: operand #0 // CHECK: - pre + // CHECK: test_tag: after_call: -// CHECK: operand #0 -// CHECK: - callee +// IP: operand #0 +// IP: - callee +// LOCAL: operand #0 +// LOCAL: - +// LC_AW: operand #0 +// LC_AW: - call + // CHECK: test_tag: return: // CHECK: operand #0 // CHECK: - post @@ -138,17 +190,29 @@ func.func private @callee(%arg0: memref) -> memref { // "pre" -> "callee" -> "call" -> "post" // CHECK-LABEL: test_tag: call_and_store_after::enter_callee: -// CHECK: operand #0 -// CHECK: - pre +// IP: operand #0 +// IP: - pre +// LOCAL: operand #0 +// LOCAL: - +// LC_AW: operand #0 +// LC_AW: - + // CHECK: test_tag: exit_callee: // CHECK: operand #0 // CHECK: - callee + // CHECK: test_tag: before_call: // CHECK: operand #0 // CHECK: - pre -// CHECK: test_tag: after_call: -// CHECK: operand #0 -// CHECK: - call + +// CHECK: test_tag: after_call: +// IP: operand #0 +// IP: - call +// LOCAL: operand #0 +// LOCAL: - +// LC_AW: operand #0 +// LC_AW: - call + // CHECK: test_tag: return: // CHECK: operand #0 // CHECK: - post @@ -162,3 +226,20 @@ func.func @call_and_store_after(%arg0: memref) -> memref { memref.store %1, %arg0[] {tag_name = "post"} : memref return {tag = "return"} %arg0 : memref } + +// ----- + +func.func private @void_return(%ptr: memref) + +// CHECK-LABEL: test_tag: after_opaque_call: +// CHECK: operand #0 +// IP_ONLY: - +// IP_AW: - func.call +func.func @test_opaque_call_return() { + %ptr = memref.alloc() : memref + %c0 = arith.constant 0 : i32 + 
memref.store %c0, %ptr[] {tag_name = "write0"} : memref + func.call @void_return(%ptr) : (memref) -> () + memref.load %ptr[] {tag = "after_opaque_call"} : memref + return +} diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir index 313a75c171d01..de0788fb6a176 100644 --- a/mlir/test/Analysis/DataFlow/test-next-access.mlir +++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir @@ -1,4 +1,22 @@ -// RUN: mlir-opt %s --test-next-access --split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-next-access --split-input-file |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP +// RUN: mlir-opt %s --test-next-access='interprocedural=false' \ +// RUN: --split-input-file |\ +// RUN: FileCheck %s --check-prefixes=CHECK,LOCAL +// RUN: mlir-opt %s --test-next-access='assume-func-reads=true' \ +// RUN: --split-input-file |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP_AR +// RUN: mlir-opt %s \ +// RUN: --test-next-access='interprocedural=false assume-func-reads=true' \ +// RUN: --split-input-file | FileCheck %s --check-prefixes=CHECK,LC_AR + +// Check prefixes are as follows: +// 'check': common for all runs; +// 'ip_ar': interprocedural runs assuming calls to external functions read +// all arguments; +// 'ip': interprocedural runs not assuming function calls reading; +// 'local': local (non-interprocedural) analysis not assuming calls reading; +// 'lc_ar': local analysis assuming external calls reading all arguments. 
// CHECK-LABEL: @trivial func.func @trivial(%arg0: memref, %arg1: f32) -> f32 { @@ -252,8 +270,10 @@ func.func @known_conditional_cf(%arg0: memref) { // ----- func.func private @callee1(%arg0: memref) { - // CHECK: name = "callee1" - // CHECK-SAME: next_access = {{\[}}["post"]] + // IP: name = "callee1" + // IP-SAME: next_access = {{\[}}["post"]] + // LOCAL: name = "callee1" + // LOCAL-SAME: next_access = ["unknown"] memref.load %arg0[] {name = "callee1"} : memref return } @@ -267,10 +287,14 @@ func.func private @callee2(%arg0: memref) { // CHECK-LABEL: @simple_call func.func @simple_call(%arg0: memref) { - // CHECK: name = "caller" - // CHECK-SAME: next_access = {{\[}}["callee1"]] + // IP: name = "caller" + // IP-SAME: next_access = {{\[}}["callee1"]] + // LOCAL: name = "caller" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "caller" + // LC_AR-SAME: next_access = {{\[}}["call"]] memref.load %arg0[] {name = "caller"} : memref - func.call @callee1(%arg0) : (memref) -> () + func.call @callee1(%arg0) {name = "call"} : (memref) -> () memref.load %arg0[] {name = "post"} : memref return } @@ -279,10 +303,14 @@ func.func @simple_call(%arg0: memref) { // CHECK-LABEL: @infinite_recursive_call func.func @infinite_recursive_call(%arg0: memref) { - // CHECK: name = "pre" - // CHECK-SAME: next_access = {{\[}}["pre"]] + // IP: name = "pre" + // IP-SAME: next_access = {{\[}}["pre"]] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["call"]] memref.load %arg0[] {name = "pre"} : memref - func.call @infinite_recursive_call(%arg0) : (memref) -> () + func.call @infinite_recursive_call(%arg0) {name = "call"} : (memref) -> () memref.load %arg0[] {name = "post"} : memref return } @@ -291,11 +319,15 @@ func.func @infinite_recursive_call(%arg0: memref) { // CHECK-LABEL: @recursive_call func.func @recursive_call(%arg0: memref, %cond: i1) { - // CHECK: name = "pre" - // CHECK-SAME: next_access = 
{{\[}}["post", "pre"]] + // IP: name = "pre" + // IP-SAME: next_access = {{\[}}["post", "pre"]] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["post", "call"]] memref.load %arg0[] {name = "pre"} : memref scf.if %cond { - func.call @recursive_call(%arg0, %cond) : (memref, i1) -> () + func.call @recursive_call(%arg0, %cond) {name = "call"} : (memref, i1) -> () } memref.load %arg0[] {name = "post"} : memref return @@ -305,12 +337,16 @@ func.func @recursive_call(%arg0: memref, %cond: i1) { // CHECK-LABEL: @recursive_call_cf func.func @recursive_call_cf(%arg0: memref, %cond: i1) { - // CHECK: name = "pre" - // CHECK-SAME: next_access = {{\[}}["pre", "post"]] + // IP: name = "pre" + // IP-SAME: next_access = {{\[}}["pre", "post"]] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["call", "post"]] %0 = memref.load %arg0[] {name = "pre"} : memref cf.cond_br %cond, ^bb1, ^bb2 ^bb1: - call @recursive_call_cf(%arg0, %cond) : (memref, i1) -> () + call @recursive_call_cf(%arg0, %cond) {name = "call"} : (memref, i1) -> () cf.br ^bb2 ^bb2: %2 = memref.load %arg0[] {name = "post"} : memref @@ -320,27 +356,35 @@ func.func @recursive_call_cf(%arg0: memref, %cond: i1) { // ----- func.func private @callee1(%arg0: memref) { - // CHECK: name = "callee1" - // CHECK-SAME: next_access = {{\[}}["post"]] + // IP: name = "callee1" + // IP-SAME: next_access = {{\[}}["post"]] + // LOCAL: name = "callee1" + // LOCAL-SAME: next_access = ["unknown"] memref.load %arg0[] {name = "callee1"} : memref return } func.func private @callee2(%arg0: memref) { - // CHECK: name = "callee2" - // CHECK-SAME: next_access = {{\[}}["post"]] + // IP: name = "callee2" + // IP-SAME: next_access = {{\[}}["post"]] + // LOCAL: name = "callee2" + // LOCAL-SAME: next_access = ["unknown"] memref.load %arg0[] {name = "callee2"} : memref return } func.func 
@conditonal_call(%arg0: memref, %cond: i1) { - // CHECK: name = "pre" - // CHECK-SAME: next_access = {{\[}}["callee1", "callee2"]] + // IP: name = "pre" + // IP-SAME: next_access = {{\[}}["callee1", "callee2"]] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["call1", "call2"]] memref.load %arg0[] {name = "pre"} : memref scf.if %cond { - func.call @callee1(%arg0) : (memref) -> () + func.call @callee1(%arg0) {name = "call1"} : (memref) -> () } else { - func.call @callee2(%arg0) : (memref) -> () + func.call @callee2(%arg0) {name = "call2"} : (memref) -> () } memref.load %arg0[] {name = "post"} : memref return @@ -354,16 +398,22 @@ func.func @conditonal_call(%arg0: memref, %cond: i1) { // "caller" -> "call" -> "callee" -> "post" func.func private @callee(%arg0: memref) { - // CHECK: name = "callee" - // CHECK-SAME-LITERAL: next_access = [["post"]] + // IP: name = "callee" + // IP-SAME-LITERAL: next_access = [["post"]] + // LOCAL: name = "callee" + // LOCAL-SAME: next_access = ["unknown"] memref.load %arg0[] {name = "callee"} : memref return } // CHECK-LABEL: @call_and_store_before func.func @call_and_store_before(%arg0: memref) { - // CHECK: name = "caller" - // CHECK-SAME-LITERAL: next_access = [["call"]] + // IP: name = "caller" + // IP-SAME-LITERAL: next_access = [["call"]] + // LOCAL: name = "caller" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "caller" + // LC_AR-SAME: next_access = {{\[}}["call"]] memref.load %arg0[] {name = "caller"} : memref // Note that the access after the entire call is "post". 
// CHECK: name = "call" @@ -382,20 +432,26 @@ func.func @call_and_store_before(%arg0: memref) { // "caller" -> "callee" -> "call" -> "post" func.func private @callee(%arg0: memref) { - // CHECK: name = "callee" - // CHECK-SAME-LITERAL: next_access = [["call"]] + // IP: name = "callee" + // IP-SAME-LITERAL: next_access = [["call"]] + // LOCAL: name = "callee" + // LOCAL-SAME: next_access = ["unknown"] memref.load %arg0[] {name = "callee"} : memref return } // CHECK-LABEL: @call_and_store_after func.func @call_and_store_after(%arg0: memref) { - // CHECK: name = "caller" - // CHECK-SAME-LITERAL: next_access = [["callee"]] + // IP: name = "caller" + // IP-SAME-LITERAL: next_access = [["callee"]] + // LOCAL: name = "caller" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "caller" + // LC_AR-SAME: next_access = {{\[}}["call"]] memref.load %arg0[] {name = "caller"} : memref // CHECK: name = "call" // CHECK-SAME-LITERAL: next_access = [["post"], ["post"]] - test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = true} : (memref, memref) -> () + test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = false} : (memref, memref) -> () // CHECK: name = "post" // CHECK-SAME-LITERAL: next_access = ["unknown"] memref.load %arg0[] {name = "post"} : memref @@ -499,3 +555,23 @@ func.func @store_with_a_region_after_containing_a_load(%arg0: memref) { memref.load %arg0[] {name = "post"} : memref return } + +// ----- + +func.func private @opaque_callee(%arg0: memref) + +// CHECK-LABEL: @call_opaque_callee +func.func @call_opaque_callee(%arg0: memref) { + // IP: name = "pre" + // IP-SAME: next_access = ["unknown"] + // IP_AR: name = "pre" + // IP_AR-SAME: next_access = {{\[}}["call"]] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["call"]] + memref.load %arg0[] {name = "pre"} : memref + func.call @opaque_callee(%arg0) {name = "call"} : (memref) -> 
() + memref.load %arg0[] {name = "post"} : memref + return +} diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir index 82fe755aaf5d4..4fc9af164d48e 100644 --- a/mlir/test/Analysis/DataFlow/test-written-to.mlir +++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir @@ -1,4 +1,28 @@ -// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 | FileCheck %s +// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP +// RUN: mlir-opt -split-input-file -test-written-to='interprocedural=false' %s \ +// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK,LOCAL +// RUN: mlir-opt -split-input-file \ +// RUN: -test-written-to='assume-func-writes=true' %s 2>&1 |\ +// RUN: FileCheck %s --check-prefixes=CHECK,IP_AW +// RUN: mlir-opt -split-input-file \ +// RUN: -test-written-to='interprocedural=false assume-func-writes=true' \ +// RUN: %s 2>&1 | FileCheck %s --check-prefixes=CHECK,LC_AW + +// Check prefixes are as follows: +// 'check': common for all runs; +// 'ip': interprocedural runs; +// 'ip_aw': interprocedural runs assuming calls to external functions write to +// all arguments; +// 'local': local (non-interprocedural) analysis not assuming calls writing; +// 'lc_aw': local analysis assuming external calls writing to all arguments. + +// Note that despite the name of the test analysis being "written to", it is set +// up in a peculiar way where passing a value through a block or region argument +// (via visitCall/BranchOperand) is considered as "writing" that value to the +// corresponding operand, which is itself a value and not necessarily "memory". +// This is arguably okay for testing purposes, but may be surprising for readers +// trying to interpret this test using their intuition. 
// CHECK-LABEL: test_tag: constant0 // CHECK: result #0: [a] @@ -105,7 +129,9 @@ func.func @test_switch(%flag: i32, %m0: memref) { // ----- // CHECK-LABEL: test_tag: add -// CHECK: result #0: [a] +// IP: result #0: [a] +// LOCAL: result #0: [callarg0] +// LC_AW: result #0: [func.call] func.func @test_caller(%m0: memref, %arg: f32) { %0 = arith.addf %arg, %arg {tag = "add"} : f32 %1 = func.call @callee(%0) : (f32) -> f32 @@ -130,7 +156,9 @@ func.func private @callee(%0 : f32) -> f32 { } // CHECK-LABEL: test_tag: sub -// CHECK: result #0: [a] +// IP: result #0: [a] +// LOCAL: result #0: [callarg0] +// LC_AW: result #0: [func.call] func.func @test_caller_below_callee(%m0: memref, %arg: f32) { %0 = arith.subf %arg, %arg {tag = "sub"} : f32 %1 = func.call @callee(%0) : (f32) -> f32 @@ -155,7 +183,9 @@ func.func private @callee3(%0 : f32) -> f32 { } // CHECK-LABEL: test_tag: mul -// CHECK: result #0: [a] +// IP: result #0: [a] +// LOCAL: result #0: [callarg0] +// LC_AW: result #0: [func.call] func.func @test_callchain(%m0: memref, %arg: f32) { %0 = arith.mulf %arg, %arg {tag = "mul"} : f32 %1 = func.call @callee1(%0) : (f32) -> f32 @@ -239,19 +269,19 @@ func.func @test_for(%m0: memref) { // ----- // CHECK-LABEL: test_tag: default_a -// CHECK-LABEL: result #0: [a] +// CHECK: result #0: [a] // CHECK-LABEL: test_tag: default_b -// CHECK-LABEL: result #0: [b] +// CHECK: result #0: [b] // CHECK-LABEL: test_tag: 1a -// CHECK-LABEL: result #0: [a] +// CHECK: result #0: [a] // CHECK-LABEL: test_tag: 1b -// CHECK-LABEL: result #0: [b] +// CHECK: result #0: [b] // CHECK-LABEL: test_tag: 2a -// CHECK-LABEL: result #0: [a] +// CHECK: result #0: [a] // CHECK-LABEL: test_tag: 2b -// CHECK-LABEL: result #0: [b] +// CHECK: result #0: [b] // CHECK-LABEL: test_tag: switch -// CHECK-LABEL: operand #0: [brancharg0] +// CHECK: operand #0: [brancharg0] func.func @test_switch(%arg0 : index, %m0: memref) { %0, %1 = scf.index_switch %arg0 {tag="switch"} -> i32, i32 case 1 { @@ -276,6 +306,9 @@ 
func.func @test_switch(%arg0 : index, %m0: memref) { // ----- +// The point of this test is to ensure the analysis doesn't crash in the presence of +// external functions. + // CHECK-LABEL: llvm.func @decl(i64) // CHECK-LABEL: llvm.func @func(%arg0: i64) { // CHECK-NEXT: llvm.call @decl(%arg0) : (i64) -> () @@ -295,12 +328,39 @@ func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 { } // CHECK-LABEL: test_tag: a -// CHECK-LABEL: operand #0: [b] -// CHECK-LABEL: operand #1: [] -// CHECK-LABEL: operand #2: [callarg2] -// CHECK-LABEL: result #0: [b] + +// IP: operand #0: [b] +// LOCAL: operand #0: [callarg0] +// LC_AW: operand #0: [test.call_on_device] + +// IP: operand #1: [] +// LOCAL: operand #1: [callarg1] +// LC_AW: operand #1: [test.call_on_device] + +// IP: operand #2: [callarg2] +// LOCAL: operand #2: [callarg2] +// LC_AW: operand #2: [test.call_on_device] + +// CHECK: result #0: [b] func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref) { %0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32) memref.store %0, %m0[] {tag_name = "b"} : memref return } + +// ----- + +func.func private @external_callee(%arg0: i32) -> i32 + +// CHECK-LABEL: test_tag: add_external +// IP: operand #0: [callarg0] +// LOCAL: operand #0: [callarg0] +// LC_AW: operand #0: [func.call] +// IP_AW: operand #0: [func.call] + +func.func @test_external_callee(%arg0: i32, %m0: memref) { + %0 = arith.addi %arg0, %arg0 { tag = "add_external"}: i32 + %1 = func.call @external_callee(%arg0) : (i32) -> i32 + memref.store %1, %m0[] {tag_name = "a"} : memref + return +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index 8bfd01d828060..ca052392f2f5f 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -49,7 +49,10 @@ 
class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { public: - using DenseBackwardDataFlowAnalysis::DenseBackwardDataFlowAnalysis; + NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, + bool assumeFuncReads = false) + : DenseBackwardDataFlowAnalysis(solver, symbolTable), + assumeFuncReads(assumeFuncReads) {} void visitOperation(Operation *op, const NextAccess &after, NextAccess *before) override; @@ -69,8 +72,10 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { // means "we don't know what the next access is" rather than "there is no next // access". But it's unclear how to differentiate the two cases... void setToExitState(NextAccess *lattice) override { - propagateIfChanged(lattice, lattice->reset()); + propagateIfChanged(lattice, lattice->setKnownToUnknown()); } + + const bool assumeFuncReads; }; } // namespace @@ -84,7 +89,13 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, SmallVector effects; memory.getEffects(effects); - ChangeResult result = before->meet(after); + + // First, check if all underlying values are already known. Otherwise, avoid + // propagating and stay in the "undefined" state to avoid incorrectly + // propagating values that may be overwritten later on as that could be + // problematic for convergence based on monotonicity of lattice updates. + SmallVector underlyingValues; + underlyingValues.reserve(effects.size()); for (const MemoryEffects::EffectInstance &effect : effects) { Value value = effect.getValue(); @@ -95,10 +106,23 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, // If cannot find the most underlying value, we cannot assume anything about // the next accesses. 
- value = UnderlyingValueAnalysis::getMostUnderlyingValue( - value, [&](Value value) { - return getOrCreateFor(op, value); - }); + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + value, [&](Value value) { + return getOrCreateFor(op, value); + }); + + // If the underlying value is not known yet, don't propagate. + if (!underlyingValue) + return; + + underlyingValues.push_back(*underlyingValue); + } + + // Update the state if all underlying values are known. + ChangeResult result = before->meet(after); + for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { + // If the underlying value is known to be unknown, set to fixpoint. if (!value) return setToExitState(before); @@ -110,6 +134,27 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, void NextAccessAnalysis::visitCallControlFlowTransfer( CallOpInterface call, CallControlFlowAction action, const NextAccess &after, NextAccess *before) { + if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { + SmallVector underlyingValues; + underlyingValues.reserve(call->getNumOperands()); + for (Value operand : call.getArgOperands()) { + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return getOrCreateFor( + call.getOperation(), value); + }); + if (!underlyingValue) + return; + underlyingValues.push_back(*underlyingValue); + } + + ChangeResult result = before->meet(after); + for (Value operand : underlyingValues) { + result |= before->set(operand, call); + } + return propagateIfChanged(before, result); + } auto testCallAndStore = dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && @@ -143,10 +188,24 @@ void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( namespace { struct TestNextAccessPass : public PassWrapper> { + TestNextAccessPass() = default; + 
TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { + interprocedural = other.interprocedural; + assumeFuncReads = other.assumeFuncReads; + } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) StringRef getArgument() const override { return "test-next-access"; } + Option interprocedural{ + *this, "interprocedural", llvm::cl::init(true), + llvm::cl::desc("perform interprocedural analysis")}; + Option assumeFuncReads{ + *this, "assume-func-reads", llvm::cl::init(false), + llvm::cl::desc( + "assume external functions have read effect on all arguments")}; + static constexpr llvm::StringLiteral kTagAttrName = "name"; static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; static constexpr llvm::StringLiteral kAtEntryPointAttrName = @@ -158,22 +217,29 @@ struct TestNextAccessPass if (!nextAccess) return StringAttr::get(op->getContext(), "not computed"); + // Note that if the underlying value could not be computed or is unknown, we + // conservatively treat the result as unknown as well. 
SmallVector attrs; for (Value operand : op->getOperands()) { - Value value = UnderlyingValueAnalysis::getMostUnderlyingValue( - operand, [&](Value value) { - return solver.lookupState(value); - }); - std::optional> nextAcc = - nextAccess->getAdjacentAccess(value); - if (!nextAcc) { + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return solver.lookupState(value); + }); + if (!underlyingValue) { + attrs.push_back(StringAttr::get(op->getContext(), "unknown")); + continue; + } + Value value = *underlyingValue; + const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); + if (!nextAcc || !nextAcc->isKnown()) { attrs.push_back(StringAttr::get(op->getContext(), "unknown")); continue; } SmallVector innerAttrs; - innerAttrs.reserve(nextAcc->size()); - for (Operation *nextAccOp : *nextAcc) { + innerAttrs.reserve(nextAcc->get().size()); + for (Operation *nextAccOp : nextAcc->get()) { if (auto nextAccTag = nextAccOp->getAttrOfType(kTagAttrName)) { innerAttrs.push_back(nextAccTag); @@ -193,9 +259,10 @@ struct TestNextAccessPass Operation *op = getOperation(); SymbolTableCollection symbolTable; - DataFlowSolver solver; + auto config = DataFlowConfig().setInterprocedural(interprocedural); + DataFlowSolver solver(config); solver.load(); - solver.load(symbolTable); + solver.load(symbolTable, assumeFuncReads); solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) { diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h index eab54fbcfbf4a..61ddc13f8a3d4 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h @@ -57,6 +57,62 @@ class UnderlyingValue { std::optional underlyingValue; }; +class AdjacentAccess { +public: + using DeterministicSetVector = + SetVector, + SmallPtrSet>; + + ArrayRef get() const { return 
accesses.getArrayRef(); } + bool isKnown() const { return !unknown; } + + ChangeResult merge(const AdjacentAccess &other) { + if (unknown) + return ChangeResult::NoChange; + if (other.unknown) { + unknown = true; + accesses.clear(); + return ChangeResult::Change; + } + + size_t sizeBefore = accesses.size(); + accesses.insert(other.accesses.begin(), other.accesses.end()); + return accesses.size() == sizeBefore ? ChangeResult::NoChange + : ChangeResult::Change; + } + + ChangeResult set(Operation *op) { + if (!unknown && accesses.size() == 1 && *accesses.begin() == op) + return ChangeResult::NoChange; + + unknown = false; + accesses.clear(); + accesses.insert(op); + return ChangeResult::Change; + } + + ChangeResult setUnknown() { + if (unknown) + return ChangeResult::NoChange; + + accesses.clear(); + unknown = true; + return ChangeResult::Change; + } + + bool operator==(const AdjacentAccess &other) const { + return unknown == other.unknown && accesses == other.accesses; + } + + bool operator!=(const AdjacentAccess &other) const { + return !operator==(other); + } + +private: + bool unknown = false; + DeterministicSetVector accesses; +}; + /// This lattice represents, for a given memory resource, the potential last /// operations that modified the resource. class AccessLatticeBase { @@ -73,40 +129,42 @@ class AccessLatticeBase { ChangeResult merge(const AccessLatticeBase &rhs) { ChangeResult result = ChangeResult::NoChange; for (const auto &mod : rhs.adjAccesses) { - auto &lhsMod = adjAccesses[mod.first]; - if (lhsMod != mod.second) { - lhsMod.insert(mod.second.begin(), mod.second.end()); - result |= ChangeResult::Change; - } + AdjacentAccess &lhsMod = adjAccesses[mod.first]; + result |= lhsMod.merge(mod.second); } return result; } /// Set the last modification of a value. 
ChangeResult set(Value value, Operation *op) { - auto &lastMod = adjAccesses[value]; + AdjacentAccess &lastMod = adjAccesses[value]; + return lastMod.set(op); + } + + ChangeResult setKnownToUnknown() { ChangeResult result = ChangeResult::NoChange; - if (lastMod.size() != 1 || *lastMod.begin() != op) { - result = ChangeResult::Change; - lastMod.clear(); - lastMod.insert(op); - } + for (auto &[value, adjacent] : adjAccesses) + result |= adjacent.setUnknown(); return result; } /// Get the adjacent accesses to a value. Returns std::nullopt if they /// are not known. - std::optional> getAdjacentAccess(Value value) const { + const AdjacentAccess *getAdjacentAccess(Value value) const { auto it = adjAccesses.find(value); if (it == adjAccesses.end()) - return {}; - return it->second.getArrayRef(); + return nullptr; + return &it->getSecond(); } void print(raw_ostream &os) const { for (const auto &lastMod : adjAccesses) { os << lastMod.first << ":\n"; - for (Operation *op : lastMod.second) + if (!lastMod.second.isKnown()) { + os << " \n"; + return; + } + for (Operation *op : lastMod.second.get()) os << " " << *op << "\n"; } } @@ -114,9 +172,7 @@ class AccessLatticeBase { private: /// The potential adjacent accesses to a memory resource. Use a set vector to /// keep the results deterministic. - DenseMap, - SmallPtrSet>> - adjAccesses; + DenseMap adjAccesses; }; /// Define the lattice class explicitly to provide a type ID. @@ -148,7 +204,7 @@ class UnderlyingValueAnalysis } /// Look for the most underlying value of a value. 
- static Value + static std::optional getMostUnderlyingValue(Value value, function_ref getUnderlyingValueFn) { @@ -156,7 +212,7 @@ class UnderlyingValueAnalysis do { underlying = getUnderlyingValueFn(value); if (!underlying || underlying->getValue().isUninitialized()) - return {}; + return std::nullopt; Value underlyingValue = underlying->getValue().getUnderlyingValue(); if (underlyingValue == value) break; diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp index 2520ed3d83b9e..29480f5ad63ee 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp @@ -49,7 +49,9 @@ class LastModification : public AbstractDenseLattice, public AccessLatticeBase { class LastModifiedAnalysis : public DenseForwardDataFlowAnalysis { public: - using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; + explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) + : DenseForwardDataFlowAnalysis(solver), + assumeFuncWrites(assumeFuncWrites) {} /// Visit an operation. If the operation has no memory effects, then the state /// is propagated with no change. If the operation allocates a resource, then @@ -74,6 +76,9 @@ class LastModifiedAnalysis void setToEntryState(LastModification *lattice) override { propagateIfChanged(lattice, lattice->reset()); } + +private: + const bool assumeFuncWrites; }; } // end anonymous namespace @@ -89,7 +94,12 @@ void LastModifiedAnalysis::visitOperation(Operation *op, SmallVector effects; memory.getEffects(effects); - ChangeResult result = after->join(before); + // First, check if all underlying values are already known. 
Otherwise, avoid + // propagating and stay in the "undefined" state to avoid incorrectly + // propagating values that may be overwritten later on as that could be + // problematic for convergence based on monotonicity of lattice updates. + SmallVector underlyingValues; + underlyingValues.reserve(effects.size()); for (const auto &effect : effects) { Value value = effect.getValue(); @@ -100,10 +110,23 @@ void LastModifiedAnalysis::visitOperation(Operation *op, // If we cannot find the underlying value, we shouldn't just propagate the // effects through, return the pessimistic state. - value = UnderlyingValueAnalysis::getMostUnderlyingValue( - value, [&](Value value) { - return getOrCreateFor(op, value); - }); + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + value, [&](Value value) { + return getOrCreateFor(op, value); + }); + + // If the underlying value is not yet known, don't propagate yet. + if (!underlyingValue) + return; + + underlyingValues.push_back(*underlyingValue); + } + + // Update the state when all underlying values are known. + ChangeResult result = after->join(before); + for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { + // If the underlying value is known to be unknown, set to fixpoint state. 
if (!value) return setToEntryState(after); @@ -119,6 +142,26 @@ void LastModifiedAnalysis::visitOperation(Operation *op, void LastModifiedAnalysis::visitCallControlFlowTransfer( CallOpInterface call, CallControlFlowAction action, const LastModification &before, LastModification *after) { + if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) { + SmallVector underlyingValues; + underlyingValues.reserve(call->getNumOperands()); + for (Value operand : call.getArgOperands()) { + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return getOrCreateFor( + call.getOperation(), value); + }); + if (!underlyingValue) + return; + underlyingValues.push_back(*underlyingValue); + } + + ChangeResult result = after->join(before); + for (Value operand : underlyingValues) + result |= after->set(operand, call); + return propagateIfChanged(after, result); + } auto testCallAndStore = dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && @@ -155,21 +198,37 @@ struct TestLastModifiedPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) + TestLastModifiedPass() = default; + TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) { + interprocedural = other.interprocedural; + assumeFuncWrites = other.assumeFuncWrites; + } + StringRef getArgument() const override { return "test-last-modified"; } + Option interprocedural{ + *this, "interprocedural", llvm::cl::init(true), + llvm::cl::desc("perform interprocedural analysis")}; + Option assumeFuncWrites{ + *this, "assume-func-writes", llvm::cl::init(false), + llvm::cl::desc( + "assume external functions have write effect on all arguments")}; + void runOnOperation() override { Operation *op = getOperation(); - DataFlowSolver solver; + DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); 
solver.load(); solver.load(); - solver.load(); + solver.load(assumeFuncWrites); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); raw_ostream &os = llvm::errs(); + // Note that if the underlying value could not be computed or is unknown, we + // conservatively treat the result as unknown as well. op->walk([&](Operation *op) { auto tag = op->getAttrOfType("tag"); if (!tag) @@ -180,19 +239,29 @@ struct TestLastModifiedPass assert(lastMods && "expected a dense lattice"); for (auto [index, operand] : llvm::enumerate(op->getOperands())) { os << " operand #" << index << "\n"; - Value value = UnderlyingValueAnalysis::getMostUnderlyingValue( - operand, [&](Value value) { - return solver.lookupState(value); - }); + std::optional underlyingValue = + UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return solver.lookupState(value); + }); + if (!underlyingValue) { + os << " - \n"; + continue; + } + Value value = *underlyingValue; assert(value && "expected an underlying value"); - if (std::optional> lastMod = + if (const AdjacentAccess *lastMod = lastMods->getAdjacentAccess(value)) { - for (Operation *lastModifier : *lastMod) { - if (auto tagName = - lastModifier->getAttrOfType("tag_name")) { - os << " - " << tagName.getValue() << "\n"; - } else { - os << " - " << lastModifier->getName() << "\n"; + if (!lastMod->isKnown()) { + os << " - \n"; + } else { + for (Operation *lastModifier : lastMod->get()) { + if (auto tagName = + lastModifier->getAttrOfType("tag_name")) { + os << " - " << tagName.getValue() << "\n"; + } else { + os << " - " << lastModifier->getName() << "\n"; + } } } } else { diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp index f97a4c8bc5eb3..e1c60f06a6b5e 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp +++ 
b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp @@ -50,7 +50,10 @@ struct WrittenTo : public AbstractSparseLattice { /// is eventually written to. class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis { public: - using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, + bool assumeFuncWrites) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), + assumeFuncWrites(assumeFuncWrites) {} void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; @@ -59,7 +62,13 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis { void visitCallOperand(OpOperand &operand) override; + void visitExternalCall(CallOpInterface call, ArrayRef operands, + ArrayRef results) override; + void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); } + +private: + bool assumeFuncWrites; }; void WrittenToAnalysis::visitOperation(Operation *op, @@ -99,6 +108,26 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) { propagateIfChanged(lattice, lattice->addWrites(newWrites)); } +void WrittenToAnalysis::visitExternalCall(CallOpInterface call, + ArrayRef operands, + ArrayRef results) { + if (!assumeFuncWrites) { + return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, + results); + } + + for (WrittenTo *lattice : operands) { + SetVector newWrites; + StringAttr name = call->getAttrOfType("tag_name"); + if (!name) { + name = StringAttr::get(call->getContext(), + call.getOperation()->getName().getStringRef()); + } + newWrites.insert(name); + propagateIfChanged(lattice, lattice->addWrites(newWrites)); + } +} + } // end anonymous namespace namespace { @@ -106,17 +135,31 @@ struct TestWrittenToPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass) + TestWrittenToPass() = default; + TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) { + 
interprocedural = other.interprocedural; + assumeFuncWrites = other.assumeFuncWrites; + } + StringRef getArgument() const override { return "test-written-to"; } + Option interprocedural{ + *this, "interprocedural", llvm::cl::init(true), + llvm::cl::desc("perform interprocedural analysis")}; + Option assumeFuncWrites{ + *this, "assume-func-writes", llvm::cl::init(false), + llvm::cl::desc( + "assume external functions have write effect on all arguments")}; + void runOnOperation() override { Operation *op = getOperation(); SymbolTableCollection symbolTable; - DataFlowSolver solver; + DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); solver.load(); solver.load(); - solver.load(symbolTable); + solver.load(symbolTable, assumeFuncWrites); if (failed(solver.initializeAndRun(op))) return signalPassFailure();