diff --git a/mlir/docs/Tutorials/DataFlowAnalysis.md b/mlir/docs/Tutorials/DataFlowAnalysis.md index ea7158fb7391d..0499a034bdbeb 100644 --- a/mlir/docs/Tutorials/DataFlowAnalysis.md +++ b/mlir/docs/Tutorials/DataFlowAnalysis.md @@ -5,20 +5,361 @@ daunting and/or complex. A dataflow analysis generally involves propagating information about the IR across various different types of control flow constructs, of which MLIR has many (Block-based branches, Region-based branches, CallGraph, etc), and it isn't always clear how best to go about performing the -propagation. To help writing these types of analyses in MLIR, this document -details several utilities that simplify the process and make it a bit more -approachable. +propagation. Dataflow analyses often require implementing fixed-point iteration +when data dependencies form cycles, as can happen with control-flow. Tracking +dependencies and making sure updates are properly propagated can get quite +difficult when writing complex analyses. That is why MLIR provides a framework +for writing general dataflow analyses as well as several utilities to streamline +the implementation of common analyses. The code and test from this tutorial can +be found in `mlir/examples/dataflow`. + +## DataFlow Analysis Framework + +MLIR provides a general dataflow analysis framework for building fixed-point +iteration dataflow analyses with ease and utilities for common dataflow +analyses. Because the landscape of IRs in MLIR can be vast, the framework is +designed to be extensible and composable, so that utilities can be shared across +dialects with different semantics as much as possible. The framework also tries +to make debugging dataflow analyses easy by providing (hopefully) insightful +logs with `-debug-only="dataflow"`. + +Suppose we want to compute at compile-time the constant-valued results of +operations. 
For example, consider: + +```mlir +%0 = string.constant "foo" +%1 = string.constant "bar" +%2 = string.concat %0, %1 +``` +We can determine with the information in the IR at compile time the value of +`%2` to be "foobar". This is called constant propagation. In MLIR's dataflow +analysis framework, this is in general called the "analysis state of a program +point"; the "state" being, in this case, the constant value, and the "program +point" being the SSA value `%2`. + +The constant value state of an SSA value is implemented as a subclass of +`AnalysisState`, and program points are represented by the `ProgramPoint` union, +which can be operations, SSA values, or blocks. They can also be just about +anything, see [Extending ProgramPoint](#extending-programpoint). In general, an +analysis state represents information about the IR computed by an analysis. + +Let us define an analysis state to represent a compile time known string value +of an SSA value: + +```c++ +class StringConstant : public AnalysisState { + /// This is the known string constant value of an SSA value at compile time + /// as determined by a dataflow analysis. To implement the concept of being + /// "uninitialized", the potential string value is wrapped in an `Optional` + /// and set to `None` by default to indicate that no value has been provided. + std::optional stringValue = std::nullopt; + +public: + using AnalysisState::AnalysisState; + + /// Return true if no value has been provided for the string constant value. + bool isUninitialized() const { return !stringValue.has_value(); } + + /// Default initialized the state to an empty string. Return whether the value + /// of the state has changed. + ChangeResult defaultInitialize() { + // If the state already has a value, do nothing. + if (!isUninitialized()) + return ChangeResult::NoChange; + // Initialize the state and indicate that its value changed. 
+ stringValue = ""; + return ChangeResult::Change; + } + + /// Get the currently known string value. + StringRef getStringValue() const { + assert(!isUninitialized() && "getting the value of an uninitialized state"); + return stringValue.value(); + } + + /// "Join" the value of the state with another constant. + ChangeResult join(const Twine &value) { + // If the current state is uninitialized, just take the value. + if (isUninitialized()) { + stringValue = value.str(); + return ChangeResult::Change; + } + // If the current state is "overdefined", no new information can be taken. + if (stringValue->empty()) + return ChangeResult::NoChange; + // If the current state has a different value, it now has two conflicting + // values and should go to overdefined. + if (stringValue != value.str()) { + stringValue = ""; + return ChangeResult::Change; + } + return ChangeResult::NoChange; + } + + /// Print the constant value. + void print(raw_ostream &os) const override { + os << stringValue.value_or("") << "\n"; + } +}; +``` + +Analysis states often depend on each other. In our example, the constant value +of `%2` depends on that of `%0` and `%1`. It stands to reason that the constant +value of `%2` needs to be recomputed when that of `%0` and `%1` change. The +`DataFlowSolver` implements the fixed-point iteration algorithm and manages the +dependency graph between analysis states. + +The computation of analysis states, on the other hand, is performed by dataflow +analyses, subclasses of `DataFlowAnalysis`. A dataflow analysis has to implement +a "transfer function", that is, code that computes the values of some states +using the values of others, and set up the dependency graph correctly. Since the +dependency graph inside the solver is initially empty, it must also set up the +dependency graph. + +```c++ +class DataFlowAnalysis { +public: + /// "Visit" the provided program point. 
This method is typically used to + /// implement transfer functions on or across program points. + virtual LogicalResult visit(ProgramPoint point) = 0; + + /// Initialize the dependency graph required by this analysis from the given + /// top-level operation. This function is called once by the solver before + /// running the fixed-point iteration algorithm. + virtual LogicalResult initialize(Operation *top) = 0; + +protected: + /// Create a dependency between the given analysis state and lattice anchor + /// on this analysis. + void addDependency(AnalysisState *state, ProgramPoint *point); + + /// Propagate an update to a state if it changed. + void propagateIfChanged(AnalysisState *state, ChangeResult changed); + + /// Get the analysis state associated with the lattice anchor. The returned + /// state is expected to be "write-only", and any updates need to be + /// propagated by `propagateIfChanged`. + template + StateT *getOrCreate(AnchorT anchor) { + return solver.getOrCreateState(anchor); + } +}; +``` + +Dependency management is a little unusual in this framework. The dependents of +the value of a state are not other states but invocations of dataflow analyses +on certain program points. For example: + +```c++ +class StringConstantPropagation : public DataFlowAnalysis { +public: + /// Implement the transfer function for string operations. When visiting a + /// string operation, this analysis will try to determine compile time values + /// of the operation's results and set them in `StringConstant` states. This + /// function is invoked on an operation whenever the states of its operands + /// are changed. + LogicalResult visit(ProgramPoint point) override { + // This function expects only to receive operations. + auto *op = point->getPrevOp(); + + // Get or create the constant string values of the operands. 
+ SmallVector operandValues; + for (Value operand : op->getOperands()) { + auto *value = getOrCreate(operand); + // Create a dependency from the state to this analysis. When the string + // value of one of the operation's operands are updated, invoke the + // transfer function again. + addDependency(value, point); + // If the state is uninitialized, bail out and come back later when it is + // initialized. + if (value->isUninitialized()) + return success(); + operandValues.push_back(value); + } + + // Try to compute a constant value of the result. + auto *result = getOrCreate(op->getResult(0)); + if (auto constant = dyn_cast(op)) { + // Just grab and set the constant value of the result of the operation. + // Propagate an update to the state if it changed. + propagateIfChanged(result, result->join(constant.getValue())); + } else if (auto concat = dyn_cast(op)) { + StringRef lhs = operandValues[0]->getStringValue(); + StringRef rhs = operandValues[1]->getStringValue(); + // If either operand is overdefined, the results are overdefined. + if (lhs.empty() || rhs.empty()) { + propagateIfChanged(result, result->defaultInitialize()); + + // Otherwise, compute the constant value and join it with the result. + } else { + propagateIfChanged(result, result->join(lhs + rhs)); + } + } else { + // We don't know how to implement the transfer function for this + // operation. Mark its results as overdefined. + propagateIfChanged(result, result->defaultInitialize()); + } + return success(); + } +}; +``` + +In the above example, the `visit` function sets up the dependencies of the +analysis invocation on an operation as the constant values of the operands of +each operation. When the operand states have initialized values but overdefined +values, it sets the state of the result to overdefined. Otherwise, it computes +the state of the result and merges the new information in with `join`. 
+ +However, the dependency graph still needs to be initialized before the solver +knows what to call `visit` on. This is done in the `initialize` function: + +```c++ +LogicalResult StringConstantPropagation::initialize(Operation *top) { + // Visit every nested string operation and set up its dependencies. + top->walk([&](Operation *op) { + for (Value operand : op->getOperands()) { + auto *state = getOrCreate<StringConstant>(operand); + addDependency(state, getProgramPointAfter(op)); + } + }); + // Now that the dependency graph has been set up, "seed" the evolution of the + // analysis by marking the constant values of all block arguments as + // overdefined and the results of (non-constant) operations with no operands. + auto defaultInitializeAll = [&](ValueRange values) { + for (Value value : values) { + auto *state = getOrCreate<StringConstant>(value); + propagateIfChanged(state, state->defaultInitialize()); + } + }; + top->walk([&](Operation *op) { + for (Region &region : op->getRegions()) + for (Block &block : region) + defaultInitializeAll(block.getArguments()); + if (auto constant = dyn_cast<string::ConstantOp>(op)) { + auto *result = getOrCreate<StringConstant>(constant.getResult()); + propagateIfChanged(result, result->join(constant.getValue())); + } else if (op->getNumOperands() == 0) { + defaultInitializeAll(op->getResults()); + } + }); + // The dependency graph has been set up and the analysis has been seeded. + // Finish initialization and let the solver run. + return success(); +} +``` + +Note that we can remove the call to `addDependency` inside our `visit` function +because the dependencies are set by the `initialize` function. Dependencies added +inside the `visit` function -- that is, while the solver is running -- are +called "dynamic dependencies". Depending on the kind of analysis, it may be +more efficient to set some dependencies statically or dynamically. + +Another way to improve the efficiency of our analysis is to recognize that this +is a *sparse*, *forward* analysis. 
It is sparse because the dependencies of an +operation's transfer function are only the states of its operands, meaning that +we can track dependencies through the IR instead of relying on the solver to do +the bookkeeping. It is forward (assuming our IR has SSA dominance) because +information can only be propagated from an SSA value's definition to its users. + +That is a lot of code to write, however, so the framework comes with utilities +for implementing conditional sparse and dense dataflow analyses. See +[Sparse Forward DataFlowAnalysis](#sparse-forward-dataflow-analysis). + +### Running the Solver + +Setting up the dataflow solver is straightforward: + +```c++ +void MyPass::runOnOperation() { + Operation *top = getOperation(); + DataFlowSolver solver; + // Load the analysis. + solver.load(); + // Run the solver! + if (failed(solver.initializeAndRun(top))) + return signalPassFailure(); + // Query the results and do something... + top->walk([&](string::ConcatOp concat) { + auto *result = solver.lookupState(concat.getResult()); + // ... + }); +} +``` + +The following is a simple example. + +```mlir +func.func @single_concat() { + %1 = string.constant "hello " + %2 = string.constant "world." + %3 = string.concat %1, %2 + return +} +``` -## Forward Dataflow Analysis +The above IR will print the following after running pass. -One type of dataflow analysis is a forward propagation analysis. This type of -analysis, as the name may suggest, propagates information forward (e.g. from -definitions to uses). To provide a bit of concrete context, let's go over -writing a simple forward dataflow analysis in MLIR. Let's say for this analysis -that we want to propagate information about a special "metadata" dictionary -attribute. The contents of this attribute are simply a set of metadata that -describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will -collect the `metadata` for operations in the IR and propagate them about. 
+```mlir +%0 = string.constant "hello " : hello +%1 = string.constant "world." : world. +%2 = string.concat %0, %1 : hello world. +``` + +### Extending ProgramPoint + +`ProgramPoint` can be extended to represent just about anything in a program, +for example control-flow edges or memory addresses. Custom "generic" program points are +implemented as subclasses of `GenericProgramPointBase`, a user of the storage +uniquer API, with a content-key. + +Example 1: a control-flow edge between two blocks. Suppose we want to represent +the state of an edge in the control-flow graph, such as its liveness. We can +attach such a state to the custom program point: + +```c++ +/// This program point represents a control-flow edge between two blocks. The +/// block `from` is a predecessor of `to`. +class CFGEdge + : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> { +public: + Block *getFrom() const { return getValue().first; } + Block *getTo() const { return getValue().second; } +}; +``` + +Example 2: a raw memory address after the execution of an operation. This +program point allows us to attach states to a raw memory address as it exists +after a given operation has executed. + +```c++ +class RawMemoryAddr : public GenericProgramPointBase< + RawMemoryAddr, std::pair<uintptr_t, Operation *>> { /* ... */ }; +``` + +Instances of program points can be accessed as follows: + +```c++ +Block *from = /* ... */, *to = /* ... */; +auto *cfgEdge = solver.getProgramPoint<CFGEdge>(from, to); + +Operation *op = /* ... */; +auto *addr = solver.getProgramPoint<RawMemoryAddr>(0x3000, op); +``` + +## Sparse Forward DataFlow Analysis + +One type of dataflow analysis is a sparse forward propagation analysis. This +type of analysis, as the name may suggest, propagates information forward (e.g. +from definitions to uses). The class `SparseForwardDataFlowAnalysis` implements much of +the analysis logic, including handling control-flow, and abstracts away the +dependency management. 
+ +To provide a bit of concrete context, let's go over writing a simple forward +dataflow analysis in MLIR. Let's say for this analysis that we want to propagate +information about a special "metadata" dictionary attribute. The contents of +this attribute are simply a set of metadata that describe a specific value, e.g. +`metadata = { likes_pizza = true }`. We will collect the `metadata` for +operations in the IR and propagate them about. ### Lattices @@ -72,31 +413,11 @@ held by an element of the lattice used by our dataflow analysis: struct MetadataLatticeValue { MetadataLatticeValue() = default; /// Compute a lattice value from the provided dictionary. - MetadataLatticeValue(DictionaryAttr attr) - : metadata(attr.begin(), attr.end()) {} - - /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown` - /// state, for our value type. The resultant state should not assume any - /// information about the state of the IR. - static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) { - // The `top`/`overdefined`/`unknown` state is when we know nothing about any - // metadata, i.e. an empty dictionary. - return MetadataLatticeValue(); - } - /// Return a pessimistic value state for our value type using only information - /// about the state of the provided IR. This is similar to the above method, - /// but may produce a slightly more refined result. This is okay, as the - /// information is already encoded as fact in the IR. - static MetadataLatticeValue getPessimisticValueState(Value value) { - // Check to see if the parent operation has metadata. - if (Operation *parentOp = value.getDefiningOp()) { - if (auto metadata = parentOp->getAttrOfType("metadata")) - return MetadataLatticeValue(metadata); - - // If no metadata is present, fallback to the - // `top`/`overdefined`/`unknown` state. 
+ MetadataLatticeValue(DictionaryAttr attr) { + for (NamedAttribute pair : attr) { + metadata.insert( + std::pair(pair.getName(), pair.getValue())); } - return MetadataLatticeValue(); } /// This method conservatively joins the information held by `lhs` and `rhs` @@ -110,33 +431,48 @@ struct MetadataLatticeValue { /// * monotonicity: join(x, join(x,y)) == join(x,y) static MetadataLatticeValue join(const MetadataLatticeValue &lhs, const MetadataLatticeValue &rhs) { - // To join `lhs` and `rhs` we will define a simple policy, which is that we - // only keep information that is the same. This means that we only keep - // facts that are true in both. - MetadataLatticeValue result; - for (const auto &lhsIt : lhs.metadata) { - // As noted above, we only merge if the values are the same. - auto it = rhs.metadata.find(lhsIt.first); - if (it == rhs.metadata.end() || it.second != lhsIt.second) - continue; - result.insert(lhsIt); - } - return result; + // To join `lhs` and `rhs` we will define a simple policy, which is that we + // directly insert the metadata of rhs into the metadata of lhs.If lhs and rhs + // have overlapping attributes, keep the attribute value in lhs unchanged. + MetadataLatticeValue result; + for (auto &&lhsIt : lhs.metadata) { + result.metadata.insert( + std::pair(lhsIt.first, lhsIt.second)); } + for (auto &&rhsIt : rhs.metadata) { + result.metadata.insert( + std::pair(rhsIt.first, rhsIt.second)); + } + return result; +} + /// A simple comparator that checks to see if this value is equal to the one /// provided. bool operator==(const MetadataLatticeValue &rhs) const { - if (metadata.size() != rhs.metadata.size()) + if (metadata.size() != rhs.metadata.size()) + return false; + + // Check that `rhs` contains the same metadata. + for (auto &&it : metadata) { + auto rhsIt = rhs.metadata.find(it.first); + if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second) return false; - // Check that `rhs` contains the same metadata. 
- for (const auto &it : metadata) { - auto rhsIt = rhs.metadata.find(it.first); - if (rhsIt == rhs.metadata.end() || it.second != rhsIt.second) - return false; - } - return true; } + return true; +} + + /// Print data in metadata. + void print(llvm::raw_ostream &os) const { + SmallVector metadataKey(metadata.keys()); + std::sort(metadataKey.begin(), metadataKey.end(), + [&](StringAttr a, StringAttr b) { return a < b; }); + os << "{"; + for (StringAttr key : metadataKey) { + os << key << ": " << metadata.at(key) << ", "; + } + os << "\b\b}\n"; +} /// Our value represents the combined metadata, which is originally a /// DictionaryAttr, so we use a map. @@ -154,7 +490,7 @@ shown below: /// This class represents a lattice element holding a specific value of type /// `ValueT`. template -class LatticeElement ... { +class Lattice ... { public: /// Return the value held by this element. This requires that a value is /// known, i.e. not `uninitialized`. @@ -168,20 +504,25 @@ public: /// Join the information contained in the 'rhs' value into this /// lattice. Returns if the state of the current lattice changed. ChangeResult join(const ValueT &rhs); - - /// Mark the lattice element as having reached a pessimistic fixpoint. This - /// means that the lattice may potentially have conflicting value states, and - /// only the conservatively known value state should be relied on. - ChangeResult markPessimisticFixPoint(); + + ... }; ``` With our lattice defined, we can now define the driver that will compute and -propagate our lattice across the IR. +propagate our lattice across the IR. The following is our definition of metadata +lattice. 
+ +```c++ +class MetadataLatticeValueLattice : public Lattice { +public: + using Lattice::Lattice; +}; +``` -### ForwardDataflowAnalysis Driver +### SparseForwardDataFlowAnalysis Driver -The `ForwardDataFlowAnalysis` class represents the driver of the dataflow +The `SparseForwardDataFlowAnalysis` class represents the driver of the dataflow analysis, and performs all of the related analysis computation. When defining our analysis, we will inherit from this class and implement some of its hooks. Before that, let's look at a quick overview of this class and some of the @@ -190,42 +531,36 @@ important API for our analysis: ```c++ /// This class represents the main driver of the forward dataflow analysis. It /// takes as a template parameter the value type of lattice being computed. -template -class ForwardDataFlowAnalysis : ... { +template +class SparseForwardDataFlowAnalysis : ... { public: - ForwardDataFlowAnalysis(MLIRContext *context); - - /// Compute the analysis on operations rooted under the given top-level - /// operation. Note that the top-level operation is not visited. - void run(Operation *topLevelOp); + explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver) + : AbstractSparseForwardDataFlowAnalysis(solver) {} + + /// Visit an operation with the lattices of its operands. This function is + /// expected to set the lattices of the operation's results. + virtual LogicalResult visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) = 0; + ... +protected: /// Return the lattice element attached to the given value. If a lattice has /// not been added for the given value, a new 'uninitialized' value is /// inserted and returned. - LatticeElement &getLatticeElement(Value value); + StateT *getLatticeElement(Value value); - /// Return the lattice element attached to the given value, or nullptr if no - /// lattice element for the value has yet been created. 
- LatticeElement *lookupLatticeElement(Value value); + /// Get the lattice element for a value and create a dependency on the + /// provided program point. + const StateT *getLatticeElementFor(ProgramPoint *point, Value value); - /// Mark all of the lattice elements for the given range of Values as having - /// reached a pessimistic fixpoint. - ChangeResult markAllPessimisticFixPoint(ValueRange values); - -protected: - /// Visit the given operation, and join any necessary analysis state - /// into the lattice elements for the results and block arguments owned by - /// this operation using the provided set of operand lattice elements - /// (all pointer values are guaranteed to be non-null). Returns if any result - /// or block argument value lattice elements changed during the visit. The - /// lattice element for a result or block argument value can be obtained, and - /// join'ed into, by using `getLatticeElement`. - virtual ChangeResult visitOperation( - Operation *op, ArrayRef *> operands) = 0; + /// Set the given lattice element(s) at control flow entry point(s). + virtual void setToEntryState(StateT *lattice) = 0; + ... }; ``` -NOTE: Some API has been redacted for our example. The `ForwardDataFlowAnalysis` +NOTE: Some API has been redacted for our example. The `SparseForwardDataFlowAnalysis` contains various other hooks that allow for injecting custom behavior when applicable. @@ -237,60 +572,92 @@ function for the operation, that is specific to our analysis. 
A simple implementation for our example is shown below: ```c++ -class MetadataAnalysis : public ForwardDataFlowAnalysis { +class MetadataAnalysis + : public SparseForwardDataFlowAnalysis { public: - using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; - - ChangeResult visitOperation( - Operation *op, ArrayRef *> operands) override { - DictionaryAttr metadata = op->getAttrOfType("metadata"); - - // If we have no metadata for this operation, we will conservatively mark - // all of the results as having reached a pessimistic fixpoint. - if (!metadata) - return markAllPessimisticFixPoint(op->getResults()); - - // Otherwise, we will compute a lattice value for the metadata and join it - // into the current lattice element for all of our results. - MetadataLatticeValue latticeValue(metadata); - ChangeResult result = ChangeResult::NoChange; - for (Value value : op->getResults()) { - // We grab the lattice element for `value` via `getLatticeElement` and - // then join it with the lattice value for this operation's metadata. Note - // that during the analysis phase, it is fine to freely create a new - // lattice element for a value. This is why we don't use the - // `lookupLatticeElement` method here. - result |= getLatticeElement(value).join(latticeValue); - } - return result; + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override { + DictionaryAttr metadata = op->getAttrOfType("metadata"); + // If we have no metadata for this operation and the operands is empty, we + // will conservatively mark all of the results as having reached a pessimistic + // fixpoint. 
+ if (!metadata && operands.empty()) { + setAllToEntryStates(results); + return success(); + } + + MetadataLatticeValue latticeValue; + if (metadata) + latticeValue = MetadataLatticeValue(metadata); + + // Otherwise, we will compute a lattice value for the metadata and join it + // into the current lattice element for all of our results. `results` stores + // the lattices corresponding to the results of op; we use a loop to traverse + // them. + for (MetadataLatticeValueLattice *result : results) { + + // `isChanged` records whether the result has been changed. + ChangeResult isChanged = ChangeResult::NoChange; + + // The op's metadata is joined into the result's lattice. + isChanged |= result->join(latticeValue); + + // The lattices of all operands of op are joined into the lattice of result. + for (auto operand : operands) + isChanged |= result->join(*operand); + + propagateIfChanged(result, isChanged); + } + return success(); } }; ``` With that, we have all of the necessary components to compute our analysis. -After the analysis has been computed, we can grab any computed information for -values by using `lookupLatticeElement`. We use this function over -`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g. -if the value is in a unreachable block, and we don't want to create a new -uninitialized lattice element in this case. See below for a quick example: +After the analysis has been defined, we need to run it using +`DataFlowSolver`, and we can grab any computed information for values by +using `lookupState`. See below for a quick example: after the pass runs the +analysis, we print the metadata of each op's results. 
```c++ void MyPass::runOnOperation() { - MetadataAnalysis analysis(&getContext()); - analysis.run(getOperation()); + Operation *op = getOperation(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // If an op has more than one result, then the lattice is the same for each + // result, and we just print one of the results. + op->walk([&](Operation *op) { + if (op->getNumResults()) { + Value result = op->getResult(0); + auto lattice = solver.lookupState(result); + llvm::outs() << OpWithFlags(op, OpPrintingFlags().skipRegions()) << " : "; + lattice->print(llvm::outs()); + } + }); ... } +``` -void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) { - LatticeElement *latticeElement = analysis.lookupLatticeElement(value); - - // If we don't have an element, the `value` wasn't visited during our analysis - // meaning that it could be dead. We need to treat this conservatively. - if (!lattice) - return; +The following is a simple example. More tests can be found in `mlir/examples/dataflow`. - // Our lattice element has a value, use it: - MetadataLatticeValue &value = lattice->getValue(); - ... +```mlir +func.func @single_join(%arg0 : index, %arg1 : index) -> index { + %1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index + %2 = arith.addi %1, %arg1 : index + return %2 : index } ``` + +The above IR will print the following after running the pass. 
+ +```mlir +%0 = arith.addi %arg0, %arg1 {metadata = {likes_pizza = true}} : index : {"likes_pizza": true} +%1 = arith.addi %0, %arg1 : index : {"likes_pizza": true} +``` diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt index 2a1cac34d8c29..6ea7c20188eb6 100644 --- a/mlir/examples/CMakeLists.txt +++ b/mlir/examples/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(dataflow) add_subdirectory(toy) add_subdirectory(transform) add_subdirectory(transform-opt) diff --git a/mlir/examples/dataflow/CMakeLists.txt b/mlir/examples/dataflow/CMakeLists.txt new file mode 100644 index 0000000000000..ad2ba8d90087a --- /dev/null +++ b/mlir/examples/dataflow/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(DataFlowExample) +set_target_properties(DataFlowExample PROPERTIES FOLDER "MLIR/Examples") + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(dataflow-opt) diff --git a/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt b/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt new file mode 100644 index 0000000000000..dbb5e3a2b2b30 --- /dev/null +++ b/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt @@ -0,0 +1,13 @@ +add_dependencies(DataFlowExample dataflow-opt) +add_llvm_example(dataflow-opt + dataflow-opt.cpp +) + +target_link_libraries(dataflow-opt + PRIVATE + MLIRIR + MLIRMlirOptMain + MLIRStringDialect + MLIRTestStringConstantPropagation + MLIRTestMetadataAnalysisPass +) diff --git a/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp b/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp new file mode 100644 index 0000000000000..c4ab2d40840ad --- /dev/null +++ b/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp @@ -0,0 +1,44 @@ +//===-- dataflow-opt.cpp - dataflow tutorial entry point ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the dataflow tutorial. +// +//===----------------------------------------------------------------------===// + +#include "StringDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace test { +void registerTestMetadataAnalysisPass(); +void registerTestStringConstantPropagation(); +}; // namespace test +} // namespace mlir + +int main(int argc, char *argv[]) { + // Register all MLIR core dialects. + mlir::DialectRegistry registry; + registerAllDialects(registry); + registerAllExtensions(registry); + + // Register String dialect. + registry.insert(); + + // Register test-string-constant-propagation pass. + mlir::test::registerTestStringConstantPropagation(); + + // Register test-metadata-analysis pass. 
+ mlir::test::registerTestMetadataAnalysisPass(); + return mlir::failed( + mlir::MlirOptMain(argc, argv, "dataflow-opt optimizer driver", registry)); +} diff --git a/mlir/examples/dataflow/include/CMakeLists.txt b/mlir/examples/dataflow/include/CMakeLists.txt new file mode 100644 index 0000000000000..c3c7b9d534561 --- /dev/null +++ b/mlir/examples/dataflow/include/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(StringOps string) +add_mlir_doc(StringOps StringDialect Dialects/ -gen-dialect-doc) diff --git a/mlir/examples/dataflow/include/MetadataAnalysis.h b/mlir/examples/dataflow/include/MetadataAnalysis.h new file mode 100644 index 0000000000000..f86aae71f9d2b --- /dev/null +++ b/mlir/examples/dataflow/include/MetadataAnalysis.h @@ -0,0 +1,79 @@ +//===-- MetadataAnalysis.h - dataflow tutorial ------------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is contains the dataflow tutorial's classes related to metadata. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_METADATA_ANALYSIS_H_ +#define MLIR_TUTORIAL_METADATA_ANALYSIS_H_ + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace llvm; + +namespace mlir { +/// The value of our lattice represents the inner structure of a DictionaryAttr, +/// for the `metadata`. +struct MetadataLatticeValue { + MetadataLatticeValue() = default; + /// Compute a lattice value from the provided dictionary. 
+ MetadataLatticeValue(DictionaryAttr attr) { + for (NamedAttribute pair : attr) { + metadata.insert( + std::pair(pair.getName(), pair.getValue())); + } + } + + /// This method conservatively joins the information held by `lhs` and `rhs` + /// into a new value. This method is required to be monotonic. `monotonicity` + /// is implied by the satisfaction of the following axioms: + /// * idempotence: join(x,x) == x + /// * commutativity: join(x,y) == join(y,x) + /// * associativity: join(x,join(y,z)) == join(join(x,y),z) + /// + /// When the above axioms are satisfied, we achieve `monotonicity`: + /// * monotonicity: join(x, join(x,y)) == join(x,y) + static MetadataLatticeValue join(const MetadataLatticeValue &lhs, + const MetadataLatticeValue &rhs); + + /// A simple comparator that checks to see if this value is equal to the one + /// provided. + bool operator==(const MetadataLatticeValue &rhs) const; + + /// Print data in metadata. + void print(llvm::raw_ostream &os) const; + + /// Our value represents the combined metadata, which is originally a + /// DictionaryAttr, so we use a map. 
+ DenseMap metadata; +}; + +namespace dataflow { +class MetadataLatticeValueLattice : public Lattice { +public: + using Lattice::Lattice; +}; + +class MetadataAnalysis + : public SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; + void setToEntryState(MetadataLatticeValueLattice *lattice) override; +}; + +} // namespace dataflow +} // namespace mlir + +#endif // MLIR_TUTORIAL_METADATA_ANALYSIS_H_ diff --git a/mlir/examples/dataflow/include/StringConstantPropagation.h b/mlir/examples/dataflow/include/StringConstantPropagation.h new file mode 100644 index 0000000000000..23a629c3882d0 --- /dev/null +++ b/mlir/examples/dataflow/include/StringConstantPropagation.h @@ -0,0 +1,95 @@ +//===-- StringConstantPropagation.h - dataflow tutorial ---------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is contains the dataflow tutorial's classes related to string +// constant propagation. +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_STRING_CONSTANT_PROPAGATION_H_ +#define MLIR_TUTORIAL_STRING_CONSTANT_PROPAGATION_H_ + +#include "mlir/Analysis/DataFlowFramework.h" + +namespace mlir { +namespace dataflow { + +class StringConstant : public AnalysisState { + /// This is the known string constant value of an SSA value at compile time + /// as determined by a dataflow analysis. To implement the concept of being + /// "uninitialized", the potential string value is wrapped in an `Optional` + /// and set to `None` by default to indicate that no value has been provided. 
+  std::optional<std::string> stringValue = std::nullopt;
+
+public:
+  using AnalysisState::AnalysisState;
+
+  /// Return true if no value has been provided for the string constant value.
+  bool isUninitialized() const { return !stringValue.has_value(); }
+
+  /// Default-initialize the state to the empty string, which this analysis
+  /// treats as "overdefined". Return whether the value of the state changed.
+  ChangeResult defaultInitialize() {
+    // If the state already has a value, do nothing.
+    if (!isUninitialized())
+      return ChangeResult::NoChange;
+    // Initialize the state and indicate that its value changed.
+    stringValue = "";
+    return ChangeResult::Change;
+  }
+
+  /// Get the currently known string value ("" means overdefined).
+  StringRef getStringValue() const {
+    assert(!isUninitialized() && "getting the value of an uninitialized state");
+    return stringValue.value();
+  }
+
+  /// "Join" the value of the state with another constant.
+  ChangeResult join(const Twine &value) {
+    // If the current state is uninitialized, just take the value.
+    if (isUninitialized()) {
+      stringValue = value.str();
+      return ChangeResult::Change;
+    }
+    // If the current state is "overdefined", no new information can be taken.
+    if (stringValue->empty())
+      return ChangeResult::NoChange;
+    // If the current state has a different value, it now has two conflicting
+    // values and should go to overdefined.
+    if (*stringValue != value.str()) {
+      stringValue = "";
+      return ChangeResult::Change;
+    }
+    return ChangeResult::NoChange;
+  }
+
+  /// Print the constant value (an empty line if uninitialized/overdefined).
+  void print(raw_ostream &os) const override {
+    os << stringValue.value_or("") << "\n";
+  }
+};
+
+class StringConstantPropagation : public DataFlowAnalysis {
+public:
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Implement the transfer function for string operations. When visiting a
+  /// string operation, this analysis will try to determine compile time values
+  /// of the operation's results and set them in `StringConstant` states.
This + /// function is invoked on an operation whenever the states of its operands + /// are changed. + LogicalResult visit(ProgramPoint *point) override; + + /// Initialize the analysis by visiting every operation with potential + /// control-flow semantics. + LogicalResult initialize(Operation *top) override; +}; + +} // namespace dataflow +} // namespace mlir + +#endif // MLIR_TUTORIAL_STRING_CONSTANT_PROPAGATION_H_ diff --git a/mlir/examples/dataflow/include/StringDialect.h b/mlir/examples/dataflow/include/StringDialect.h new file mode 100644 index 0000000000000..51ceb3180453b --- /dev/null +++ b/mlir/examples/dataflow/include/StringDialect.h @@ -0,0 +1,30 @@ +//===- StringDialect.h - Dialect definition for the String IR -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the String Dialect for the StringConstantPropagation. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_STRING_DIALECT_H_ +#define MLIR_TUTORIAL_STRING_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_TYPEDEF_CLASSES +#include "StringOpsTypes.h.inc" + +#include "StringOpsDialect.h.inc" +#define GET_OP_CLASSES +#include "StringOps.h.inc" + +#endif // MLIR_TUTORIAL_STRING_DIALECT_H_ diff --git a/mlir/examples/dataflow/include/StringOps.td b/mlir/examples/dataflow/include/StringOps.td new file mode 100644 index 0000000000000..bcfa6c0d3fe0d --- /dev/null +++ b/mlir/examples/dataflow/include/StringOps.td @@ -0,0 +1,75 @@ +//===-- StringOps.td - String dialect operation definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the String dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef STRING_OPS +#define STRING_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def String_Dialect : Dialect { + let name = "string"; + let cppNamespace = "mlir::string"; + let description = [{ + The `String` dialect provides string-related operations, + which are used for string constant propagation. 
+ }]; + let useDefaultTypePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// String Type Definitions +//===----------------------------------------------------------------------===// + +def StringType : TypeDef { + let mnemonic = "str"; + let summary = "string literal"; + let description = [{ + `string.str` is a type returned by ops of string dialect. + }]; +} + +//===----------------------------------------------------------------------===// +// String Op Definitions +//===----------------------------------------------------------------------===// + +class String_Op traits = [Pure]> : + Op {} + +def ConstantOp : String_Op<"constant"> { + let arguments = (ins StrAttr:$value); + let results = (outs StringType:$res); + let description = [{ + The `string.constant` op is used to create string constants. + ```mlir + %bar = string.constant "bar" + ``` + }]; + let assemblyFormat = "$value attr-dict"; +} + +def ConcatOp : String_Op<"concat"> { + let arguments = (ins StringType:$lhs, StringType:$rhs); + let results = (outs StringType:$res); + let description = [{ + The `string.concat` op is used to concatenate strings. 
+ ```mlir + %bar = string.constant "bar" + %foo = string.constant "foo" + %concat = string.concat %bar, %foo + ``` + }]; + let assemblyFormat = "$lhs `,` $rhs attr-dict"; +} + +#endif // STRING_OPS diff --git a/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt b/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000000..861440e371141 --- /dev/null +++ b/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(StringConstantPropagation) +add_subdirectory(MetadataAnalysis) diff --git a/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/CMakeLists.txt b/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/CMakeLists.txt new file mode 100644 index 0000000000000..c9e07f520e0be --- /dev/null +++ b/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_example_library(MLIRMetadataAnalysis + MetadataAnalysis.cpp + + LINK_LIBS PUBLIC + MLIRAnalysis + ) diff --git a/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/MetadataAnalysis.cpp b/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/MetadataAnalysis.cpp new file mode 100644 index 0000000000000..100d5b8214555 --- /dev/null +++ b/mlir/examples/dataflow/lib/Analysis/MetadataAnalysis/MetadataAnalysis.cpp @@ -0,0 +1,117 @@ +//===-- MetadataAnalysis.cpp - dataflow tutorial ----------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is contains the implementations of the methods in the +// metadata-related classes in the dataflow tutorial. 
+// +//===----------------------------------------------------------------------===// + +#include "MetadataAnalysis.h" + +using namespace mlir; +using namespace dataflow; + +namespace mlir { + +/// This method conservatively joins the information held by `lhs` and `rhs` +/// into a new value. This method is required to be monotonic. `monotonicity` +/// is implied by the satisfaction of the following axioms: +/// * idempotence: join(x,x) == x +/// * commutativity: join(x,y) == join(y,x) +/// * associativity: join(x,join(y,z)) == join(join(x,y),z) +/// +/// When the above axioms are satisfied, we achieve `monotonicity`: +/// * monotonicity: join(x, join(x,y)) == join(x,y) +MetadataLatticeValue +MetadataLatticeValue::join(const MetadataLatticeValue &lhs, + const MetadataLatticeValue &rhs) { + // To join `lhs` and `rhs` we will define a simple policy, which is that we + // directly insert the metadata of rhs into the metadata of lhs.If lhs and rhs + // have overlapping attributes, keep the attribute value in lhs unchanged. + MetadataLatticeValue result; + for (auto &&lhsIt : lhs.metadata) { + result.metadata.insert( + std::pair(lhsIt.first, lhsIt.second)); + } + + for (auto &&rhsIt : rhs.metadata) { + result.metadata.insert( + std::pair(rhsIt.first, rhsIt.second)); + } + return result; +} + +/// A simple comparator that checks to see if this value is equal to the one +/// provided. +bool MetadataLatticeValue::operator==(const MetadataLatticeValue &rhs) const { + if (metadata.size() != rhs.metadata.size()) + return false; + + // Check that `rhs` contains the same metadata. 
+  for (auto &&it : metadata) {
+    auto rhsIt = rhs.metadata.find(it.first);
+    if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second)
+      return false;
+  }
+  return true;
+}
+
+/// Print the metadata as `{key: value, ...}` with the keys in sorted order.
+void MetadataLatticeValue::print(llvm::raw_ostream &os) const {
+  SmallVector<StringAttr> metadataKey(metadata.keys());
+  llvm::sort(metadataKey, [](StringAttr a, StringAttr b) { return a < b; });
+  // Separate entries with ", " instead of erasing a trailing one with "\b\b".
+  os << "{";
+  llvm::interleaveComma(metadataKey, os, [&](StringAttr key) {
+    os << key << ": " << metadata.at(key);
+  });
+  os << "}\n";
+}
+
+namespace dataflow {
+LogicalResult MetadataAnalysis::visitOperation(
+    Operation *op, ArrayRef<const MetadataLatticeValueLattice *> operands,
+    ArrayRef<MetadataLatticeValueLattice *> results) {
+  DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
+  // If we have no metadata for this operation and there are no operands, we
+  // will conservatively mark all of the results as having reached a pessimistic
+  // fixpoint.
+  if (!metadata && operands.empty()) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  MetadataLatticeValue latticeValue;
+  if (metadata)
+    latticeValue = MetadataLatticeValue(metadata);
+
+  // Otherwise, we will compute a lattice value for the metadata and join it
+  // into the current lattice element for all of our results. `results` stores
+  // the lattices corresponding to the results of op; we use a loop to traverse
+  // them.
+  for (MetadataLatticeValueLattice *result : results) {
+    // `isChanged` records whether the result has been changed.
+    ChangeResult isChanged = ChangeResult::NoChange;
+
+    // The op's own metadata is joined into the result's lattice first.
+    isChanged |= result->join(latticeValue);
+
+    // Then the lattice of every operand is joined into the result's lattice.
+    for (const MetadataLatticeValueLattice *operand : operands)
+      isChanged |= result->join(*operand);
+
+    propagateIfChanged(result, isChanged);
+  }
+  return success();
+}
+
+/// At an entry point, we leave its function body empty because no metadata can
+/// be joined to Lattice.
+void MetadataAnalysis::setToEntryState(MetadataLatticeValueLattice *lattice) {}
+} // namespace dataflow
+} // namespace mlir
diff --git a/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/CMakeLists.txt b/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/CMakeLists.txt
new file mode 100644
index 0000000000000..1d2afe7ae281a
--- /dev/null
+++ b/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_example_library(MLIRStringConstantPropagation
+  StringConstantPropagation.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRAnalysis
+  )
diff --git a/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/StringConstantPropagation.cpp b/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/StringConstantPropagation.cpp
new file mode 100644
index 0000000000000..7108f21a15183
--- /dev/null
+++ b/mlir/examples/dataflow/lib/Analysis/StringConstantPropagation/StringConstantPropagation.cpp
@@ -0,0 +1,106 @@
+//===-- StringConstantPropagation.cpp - dataflow tutorial -------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementations of the visit and initialize
+// methods for StringConstantPropagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "StringConstantPropagation.h"
+#include "StringDialect.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "string-constant-propagation"
+
+using namespace mlir;
+using namespace dataflow;
+
+LogicalResult StringConstantPropagation::visit(ProgramPoint *point) {
+  LDBG() << "Visiting program point: " << point << " " << *point;
+  // This function expects only to receive operations. `initialize` also
+  // subscribes operations that produce no results (e.g. terminators with
+  // operands); there is nothing to compute for those, so bail out early.
+  auto *op = point->getPrevOp();
+  if (!op || op->getNumResults() == 0)
+    return success();
+
+  // Get or create the constant string values of the operands.
+  SmallVector<StringConstant *> operandValues;
+  for (Value operand : op->getOperands()) {
+    auto *value = getOrCreate<StringConstant>(operand);
+    // Create a dependency from the state to this analysis. When the string
+    // value of one of the operation's operands is updated, invoke the
+    // transfer function again.
+    addDependency(value, point);
+    // If the state is uninitialized, bail out and come back later when it is
+    // initialized.
+    if (value->isUninitialized())
+      return success();
+    operandValues.push_back(value);
+  }
+
+  // Try to compute a constant value of the result.
+  auto *result = getOrCreate<StringConstant>(op->getResult(0));
+  if (auto constant = dyn_cast<string::ConstantOp>(op)) {
+    // Just grab and set the constant value of the result of the operation.
+    // Propagate an update to the state if it changed.
+    propagateIfChanged(result, result->join(constant.getValue()));
+  } else if (auto concat = dyn_cast<string::ConcatOp>(op)) {
+    StringRef lhs = operandValues[0]->getStringValue();
+    StringRef rhs = operandValues[1]->getStringValue();
+    // If either operand is overdefined, the results are overdefined.
+    if (lhs.empty() || rhs.empty()) {
+      // Joining "" forces overdefined even if a constant was recorded before.
+      propagateIfChanged(result, result->join(""));
+      // Otherwise, compute the constant value and join it with the result.
+    } else {
+      propagateIfChanged(result, result->join(lhs + rhs));
+    }
+  } else {
+    // We don't know how to implement the transfer function for this
+    // operation. Mark its results as overdefined.
+    propagateIfChanged(result, result->join(""));
+  }
+  return success();
+}
+
+LogicalResult StringConstantPropagation::initialize(Operation *top) {
+  LDBG() << "Initializing StringConstantPropagation for top-level op: "
+         << top->getName();
+  // Visit every nested string operation and set up its dependencies.
+ top->walk([&](Operation *op) { + for (Value operand : op->getOperands()) { + auto *state = getOrCreate(operand); + addDependency(state, getProgramPointAfter(op)); + } + }); + // Now that the dependency graph has been set up, "seed" the evolution of the + // analysis by marking the constant values of all block arguments as + // overdefined and the results of (non-constant) operations with no operands. + auto defaultInitializeAll = [&](ValueRange values) { + for (Value value : values) { + auto *state = getOrCreate(value); + propagateIfChanged(state, state->defaultInitialize()); + } + }; + top->walk([&](Operation *op) { + for (Region ®ion : op->getRegions()) + for (Block &block : region) + defaultInitializeAll(block.getArguments()); + if (auto constant = dyn_cast(op)) { + auto *result = getOrCreate(constant.getResult()); + propagateIfChanged(result, result->join(constant.getValue())); + } else if (op->getNumOperands() == 0) { + defaultInitializeAll(op->getResults()); + } + }); + // The dependency graph has been set up and the analysis has been seeded. + // Finish initialization and let the solver run. 
+ return success(); +} diff --git a/mlir/examples/dataflow/lib/CMakeLists.txt b/mlir/examples/dataflow/lib/CMakeLists.txt new file mode 100644 index 0000000000000..c3e877ffdc512 --- /dev/null +++ b/mlir/examples/dataflow/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(Pass) diff --git a/mlir/examples/dataflow/lib/Dialect/CMakeLists.txt b/mlir/examples/dataflow/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000000..8b71f6c30bbd7 --- /dev/null +++ b/mlir/examples/dataflow/lib/Dialect/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRStringDialect + StringDialect.cpp + + DEPENDS + MLIRStringOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect +) diff --git a/mlir/examples/dataflow/lib/Dialect/StringDialect.cpp b/mlir/examples/dataflow/lib/Dialect/StringDialect.cpp new file mode 100644 index 0000000000000..c5da828088c3f --- /dev/null +++ b/mlir/examples/dataflow/lib/Dialect/StringDialect.cpp @@ -0,0 +1,39 @@ +//===- StringDialect.cpp - String ops implementation ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the String dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "StringDialect.h" + +using namespace mlir; +using namespace string; + +#include "StringOpsDialect.cpp.inc" + +void StringDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "StringOpsTypes.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "StringOps.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "StringOps.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "StringOpsTypes.cpp.inc" diff --git a/mlir/examples/dataflow/lib/Pass/CMakeLists.txt b/mlir/examples/dataflow/lib/Pass/CMakeLists.txt new file mode 100644 index 0000000000000..0af397265a7fe --- /dev/null +++ b/mlir/examples/dataflow/lib/Pass/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TestStringConstantPropagation) +add_subdirectory(TestMetadataAnalsys) diff --git a/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/CMakeLists.txt b/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/CMakeLists.txt new file mode 100644 index 0000000000000..267decb716c4b --- /dev/null +++ b/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_example_library(MLIRTestMetadataAnalysisPass + TestMetadataAnalsys.cpp + + LINK_LIBS PUBLIC + MLIRMetadataAnalysis +) diff --git a/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/TestMetadataAnalsys.cpp b/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/TestMetadataAnalsys.cpp new file mode 100644 index 0000000000000..ca059fea3f173 --- /dev/null +++ b/mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys/TestMetadataAnalsys.cpp @@ -0,0 +1,59 @@ +//===-- TestMetadataAnalsys.cpp - dataflow tutorial -------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is contains the implementation of TestMetadataAnalysisPass. +// +//===----------------------------------------------------------------------===// + +#include "MetadataAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace dataflow; + +namespace mlir { +namespace { +class TestMetadataAnalysisPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMetadataAnalysisPass) + StringRef getArgument() const final { return "test-metadata-analysis"; } + StringRef getDescription() const final { return "Tests metadata analysis"; } + TestMetadataAnalysisPass() = default; + TestMetadataAnalysisPass(const TestMetadataAnalysisPass &) {} + void runOnOperation() override { + Operation *op = getOperation(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // If an op has more than one result, then the lattice is the same for each + // result, and we just print one of the results. 
+ op->walk([&](Operation *op) { + if (op->getNumResults()) { + Value result = op->getResult(0); + auto lattice = solver.lookupState(result); + llvm::outs() << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " : "; + lattice->print(llvm::outs()); + } + }); + } +}; +} // namespace + +namespace test { +void registerTestMetadataAnalysisPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/CMakeLists.txt b/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/CMakeLists.txt new file mode 100644 index 0000000000000..dfa438db56207 --- /dev/null +++ b/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_example_library(MLIRTestStringConstantPropagation + TestStringConstantPropagation.cpp + + LINK_LIBS PUBLIC + MLIRStringConstantPropagation +) diff --git a/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/TestStringConstantPropagation.cpp b/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/TestStringConstantPropagation.cpp new file mode 100644 index 0000000000000..0a58b8d716093 --- /dev/null +++ b/mlir/examples/dataflow/lib/Pass/TestStringConstantPropagation/TestStringConstantPropagation.cpp @@ -0,0 +1,61 @@ +//===-- TestStringConstantPropagation.cpp - dataflow tutorial ---*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is contains the implementation of TestStringConstantPropagation. 
+// +//===----------------------------------------------------------------------===// + +#include "StringConstantPropagation.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace dataflow; + +namespace mlir { +namespace { +class TestStringConstantPropagation + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStringConstantPropagation) + StringRef getArgument() const final { + return "test-string-constant-propagation"; + } + StringRef getDescription() const final { + return "Tests string constant propagation"; + } + TestStringConstantPropagation() = default; + TestStringConstantPropagation(const TestStringConstantPropagation &) {} + void runOnOperation() override { + Operation *op = getOperation(); + DataFlowSolver solver; + // Load the analysis. + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // Query the results and do something... + op->walk([&](Operation *op) { + if (op->getNumResults()) { + Value result = op->getResult(0); + auto stringConstant = solver.lookupState(result); + llvm::outs() << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " : "; + stringConstant->print(llvm::outs()); + } + }); + } +}; +} // namespace + +namespace test { +void registerTestStringConstantPropagation() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 89568e7766ae5..962a79c7038cf 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -192,6 +192,9 @@ if(LLVM_BUILD_EXAMPLES) mlir-minimal-opt mlir-transform-opt ) + list(APPEND MLIR_TEST_DEPENDS + dataflow-opt + ) if(MLIR_ENABLE_EXECUTION_ENGINE) list(APPEND MLIR_TEST_DEPENDS toyc-ch6 diff --git a/mlir/test/Examples/dataflow/metadata-analysis.mlir b/mlir/test/Examples/dataflow/metadata-analysis.mlir new file mode 100644 index 0000000000000..4c9cc4ce381ae --- /dev/null +++ 
b/mlir/test/Examples/dataflow/metadata-analysis.mlir @@ -0,0 +1,42 @@ +// RUN: dataflow-opt %s -test-metadata-analysis -split-input-file | FileCheck %s + +// CHECK: {{.*}} = arith.addi {{.*}}, {{.*}} {metadata = {likes_pizza = true}} : index +// CHECK: {{"likes_pizza": true}} +// CHECK-NEXT: {{.*}} = arith.addi {{.*}}, {{.*}} : index +// CHECK: {{"likes_pizza": true}} +func.func @single_join(%arg0 : index, %arg1 : index) -> index { + %1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index + %2 = arith.addi %1, %arg1 : index + return %2 : index +} + +// ----- + +// CHECK: {{.*}} = arith.addi {{.*}}, {{.*}} {metadata = {likes_pizza = true}} : index +// CHECK: {{"likes_pizza": true}} +// CHECK-NEXT: {{.*}} = arith.addi {{.*}}, {{.*}} {metadata = {likes_hotdog = true}} : index +// CHECK: {{"likes_hotdog": true}} +// CHECK-NEXT: {{.*}} = arith.addi {{.*}}, {{.*}} : index +// CHECK: {{"likes_hotdog": true, "likes_pizza": true}} +func.func @muti_join(%arg0 : index, %arg1 : index) -> index { + %1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index + %2 = arith.addi %arg0, %arg1 {metadata = { likes_hotdog = true }} : index + %3 = arith.addi %1, %2 : index + return %3 : index +} + +// ----- + +// CHECK: {{.*}} = arith.addi {{.*}}, {{.*}} {metadata = {likes_pizza = true}} : index +// CHECK: {{"likes_pizza": true}} +// CHECK-NEXT: {{.*}} = arith.addi {{.*}}, {{.*}} {metadata = {likes_pizza = false}} : index +// CHECK: {{"likes_pizza": false}} +// CHECK-NEXT: {{.*}} = arith.addi {{.*}}, {{.*}} : index +// CHECK: {{"likes_pizza": true}} + +func.func @conflict_join(%arg0 : index, %arg1 : index) -> index { + %1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index + %2 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = false }} : index + %3 = arith.addi %1, %2 : index + return %3 : index +} diff --git a/mlir/test/Examples/dataflow/string-constant-propagation.mlir b/mlir/test/Examples/dataflow/string-constant-propagation.mlir 
new file mode 100644 index 0000000000000..b5d5a80a37d15 --- /dev/null +++ b/mlir/test/Examples/dataflow/string-constant-propagation.mlir @@ -0,0 +1,36 @@ +// RUN: dataflow-opt %s -test-string-constant-propagation -split-input-file | FileCheck %s + +// CHECK: {{.*}} = string.constant "hello " +// CHECK: {{hello}} +// CHECK-NEXT: {{.*}} = string.constant "world." +// CHECK: {{world}} +// CHECK-NEXT: {{.*}} = string.concat {{.*}}, {{.*}} : +// CHECK: {{hello world.}} +func.func @single_concat() { + %1 = string.constant "hello " + %2 = string.constant "world." + %3 = string.concat %1, %2 + return +} + +// ----- + +// CHECK: {{.*}} = string.constant "data" +// CHECK: {{data}} +// CHECK-NEXT: {{.*}} = string.constant "flow " +// CHECK: {{flow}} +// CHECK-NEXT: {{.*}} = string.constant "tutorial" +// CHECK: {{tutorial}} +// CHECK-NEXT: {{.*}} = string.concat {{.*}}, {{.*}} : +// CHECK: {{dataflow}} +// CHECK-NEXT: {{.*}} = string.concat {{.*}}, {{.*}} : +// CHECK: {{dataflow tutorial}} + +func.func @mult_concat() { + %1 = string.constant "data" + %2 = string.constant "flow " + %3 = string.constant "tutorial" + %4 = string.concat %1, %2 + %5 = string.concat %4, %3 + return +}