Skip to content
Open
270 changes: 148 additions & 122 deletions mlir/docs/Tutorials/DataFlowAnalysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ constructs, of which MLIR has many (Block-based branches, Region-based branches,
CallGraph, etc), and it isn't always clear how best to go about performing the
propagation. To help writing these types of analyses in MLIR, this document
details several utilities that simplify the process and make it a bit more
approachable.
approachable. The code from this tutorial can be found in `mlir/examples/dataflow`.

## Forward Dataflow Analysis

Expand Down Expand Up @@ -72,31 +72,11 @@ held by an element of the lattice used by our dataflow analysis:
struct MetadataLatticeValue {
MetadataLatticeValue() = default;
/// Compute a lattice value from the provided dictionary.
MetadataLatticeValue(DictionaryAttr attr)
: metadata(attr.begin(), attr.end()) {}

/// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
/// state, for our value type. The resultant state should not assume any
/// information about the state of the IR.
static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
// The `top`/`overdefined`/`unknown` state is when we know nothing about any
// metadata, i.e. an empty dictionary.
return MetadataLatticeValue();
}
/// Return a pessimistic value state for our value type using only information
/// about the state of the provided IR. This is similar to the above method,
/// but may produce a slightly more refined result. This is okay, as the
/// information is already encoded as fact in the IR.
static MetadataLatticeValue getPessimisticValueState(Value value) {
// Check to see if the parent operation has metadata.
if (Operation *parentOp = value.getDefiningOp()) {
if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
return MetadataLatticeValue(metadata);

// If no metadata is present, fallback to the
// `top`/`overdefined`/`unknown` state.
MetadataLatticeValue(DictionaryAttr attr) {
for (NamedAttribute pair : attr) {
metadata.insert(
std::pair<StringAttr, Attribute>(pair.getName(), pair.getValue()));
}
return MetadataLatticeValue();
}

/// This method conservatively joins the information held by `lhs` and `rhs`
Expand All @@ -110,33 +90,48 @@ struct MetadataLatticeValue {
/// * monotonicity: join(x, join(x,y)) == join(x,y)
static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
const MetadataLatticeValue &rhs) {
// To join `lhs` and `rhs` we will define a simple policy, which is that we
// only keep information that is the same. This means that we only keep
// facts that are true in both.
MetadataLatticeValue result;
for (const auto &lhsIt : lhs.metadata) {
// As noted above, we only merge if the values are the same.
auto it = rhs.metadata.find(lhsIt.first);
if (it == rhs.metadata.end() || it.second != lhsIt.second)
continue;
result.insert(lhsIt);
}
return result;
// To join `lhs` and `rhs` we will define a simple policy, which is that we
// directly insert the metadata of rhs into the metadata of lhs.If lhs and rhs
// have overlapping attributes, keep the attribute value in lhs unchanged.
MetadataLatticeValue result;
for (auto &&lhsIt : lhs.metadata) {
result.metadata.insert(
std::pair<StringAttr, Attribute>(lhsIt.first, lhsIt.second));
}

for (auto &&rhsIt : rhs.metadata) {
result.metadata.insert(
std::pair<StringAttr, Attribute>(rhsIt.first, rhsIt.second));
}
return result;
}

/// A simple comparator that checks to see if this value is equal to the one
/// provided.
bool operator==(const MetadataLatticeValue &rhs) const {
if (metadata.size() != rhs.metadata.size())
if (metadata.size() != rhs.metadata.size())
return false;

// Check that `rhs` contains the same metadata.
for (auto &&it : metadata) {
auto rhsIt = rhs.metadata.find(it.first);
if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second)
return false;
// Check that `rhs` contains the same metadata.
for (const auto &it : metadata) {
auto rhsIt = rhs.metadata.find(it.first);
if (rhsIt == rhs.metadata.end() || it.second != rhsIt.second)
return false;
}
return true;
}
return true;
}

/// Print data in metadata.
void print(llvm::raw_ostream &os) const {
SmallVector<StringAttr> metadataKey(metadata.keys());
std::sort(metadataKey.begin(), metadataKey.end(),
[&](StringAttr a, StringAttr b) { return a < b; });
os << "{";
for (StringAttr key : metadataKey) {
os << key << ": " << metadata.at(key) << ", ";
}
os << "\b\b}\n";
}

/// Our value represents the combined metadata, which is originally a
/// DictionaryAttr, so we use a map.
Expand All @@ -154,7 +149,7 @@ shown below:
/// This class represents a lattice element holding a specific value of type
/// `ValueT`.
template <typename ValueT>
class LatticeElement ... {
class Lattice ... {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says This class represents a lattice element... , why did you rename it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class corresponding to the current lattice element in C++ should be Lattice.

public:
/// Return the value held by this element. This requires that a value is
/// known, i.e. not `uninitialized`.
Expand All @@ -168,20 +163,25 @@ public:
/// Join the information contained in the 'rhs' value into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const ValueT &rhs);

/// Mark the lattice element as having reached a pessimistic fixpoint. This
/// means that the lattice may potentially have conflicting value states, and
/// only the conservatively known value state should be relied on.
ChangeResult markPessimisticFixPoint();

...
};
```

With our lattice defined, we can now define the driver that will compute and
propagate our lattice across the IR.
propagate our lattice across the IR. The following is our definition of metadata
lattice.

```c++
class MetadataLatticeValueLattice : public Lattice<MetadataLatticeValue> {
public:
using Lattice::Lattice;
};
```

### ForwardDataflowAnalysis Driver
### SparseForwardDataFlowAnalysis Driver

The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
The `SparseForwardDataFlowAnalysis` class represents the driver of the dataflow
analysis, and performs all of the related analysis computation. When defining
our analysis, we will inherit from this class and implement some of its hooks.
Before that, let's look at a quick overview of this class and some of the
Expand All @@ -190,42 +190,36 @@ important API for our analysis:
```c++
/// This class represents the main driver of the forward dataflow analysis. It
/// takes as a template parameter the value type of lattice being computed.
template <typename ValueT>
class ForwardDataFlowAnalysis : ... {
template <typename StateT>
class SparseForwardDataFlowAnalysis : ... {
public:
ForwardDataFlowAnalysis(MLIRContext *context);

/// Compute the analysis on operations rooted under the given top-level
/// operation. Note that the top-level operation is not visited.
void run(Operation *topLevelOp);
explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver)
: AbstractSparseForwardDataFlowAnalysis(solver) {}

/// Visit an operation with the lattices of its operands. This function is
/// expected to set the lattices of the operation's results.
virtual LogicalResult visitOperation(Operation *op,
ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;
...

protected:
/// Return the lattice element attached to the given value. If a lattice has
/// not been added for the given value, a new 'uninitialized' value is
/// inserted and returned.
LatticeElement<ValueT> &getLatticeElement(Value value);

/// Return the lattice element attached to the given value, or nullptr if no
/// lattice element for the value has yet been created.
LatticeElement<ValueT> *lookupLatticeElement(Value value);
StateT *getLatticeElement(Value value);

/// Mark all of the lattice elements for the given range of Values as having
/// reached a pessimistic fixpoint.
ChangeResult markAllPessimisticFixPoint(ValueRange values);
/// Get the lattice element for a value and create a dependency on the
/// provided program point.
const StateT *getLatticeElementFor(ProgramPoint *point, Value value);

protected:
/// Visit the given operation, and join any necessary analysis state
/// into the lattice elements for the results and block arguments owned by
/// this operation using the provided set of operand lattice elements
/// (all pointer values are guaranteed to be non-null). Returns if any result
/// or block argument value lattice elements changed during the visit. The
/// lattice element for a result or block argument value can be obtained, and
/// join'ed into, by using `getLatticeElement`.
virtual ChangeResult visitOperation(
Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
/// Set the given lattice element(s) at control flow entry point(s).
virtual void setToEntryState(StateT *lattice) = 0;
...
};
```

NOTE: Some API has been redacted for our example. The `ForwardDataFlowAnalysis`
NOTE: Some API has been redacted for our example. The `SparseForwardDataFlowAnalysis`
contains various other hooks that allow for injecting custom behavior when
applicable.

Expand All @@ -237,60 +231,92 @@ function for the operation, that is specific to our analysis. A simple
implementation for our example is shown below:

```c++
class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
class MetadataAnalysis
: public SparseForwardDataFlowAnalysis<MetadataLatticeValueLattice> {
public:
using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;

ChangeResult visitOperation(
Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) override {
DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");

// If we have no metadata for this operation, we will conservatively mark
// all of the results as having reached a pessimistic fixpoint.
if (!metadata)
return markAllPessimisticFixPoint(op->getResults());

// Otherwise, we will compute a lattice value for the metadata and join it
// into the current lattice element for all of our results.
MetadataLatticeValue latticeValue(metadata);
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
// We grab the lattice element for `value` via `getLatticeElement` and
// then join it with the lattice value for this operation's metadata. Note
// that during the analysis phase, it is fine to freely create a new
// lattice element for a value. This is why we don't use the
// `lookupLatticeElement` method here.
result |= getLatticeElement(value).join(latticeValue);
}
return result;
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
LogicalResult
visitOperation(Operation *op,
ArrayRef<const MetadataLatticeValueLattice *> operands,
ArrayRef<MetadataLatticeValueLattice *> results) override {
DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
// If we have no metadata for this operation and the operands is empty, we
// will conservatively mark all of the results as having reached a pessimistic
// fixpoint.
if (!metadata && operands.empty()) {
setAllToEntryStates(results);
return success();
}

MetadataLatticeValue latticeValue;
if (metadata)
latticeValue = MetadataLatticeValue(metadata);

// Otherwise, we will compute a lattice value for the metadata and join it
// into the current lattice element for all of our results.`results` stores
// the lattices corresponding to the results of op, We use a loop to traverse
// them.
for (auto result : results) {

// `isChanged` records whether the result has been changed.
ChangeResult isChanged = ChangeResult::NoChange;

// Op's metadata is joined result's lattice.
isChanged |= result->join(latticeValue);

// All lattice of operands of op are joined to the lattice of result.
for (auto operand : operands)
isChanged |= result->join(*operand);

propagateIfChanged(result, isChanged);
}
return success();
}
};
```

With that, we have all of the necessary components to compute our analysis.
After the analysis has been computed, we can grab any computed information for
values by using `lookupLatticeElement`. We use this function over
`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
if the value is in a unreachable block, and we don't want to create a new
uninitialized lattice element in this case. See below for a quick example:
After the analysis has been computed, we need to run our analysis using
`DataFlowSolver`, and we can grab any computed information for values by
using `lookupState`. See below for a quick example, after the pass runs the
analysis, we print the metadata of each op's results.

```c++
void MyPass::runOnOperation() {
MetadataAnalysis analysis(&getContext());
analysis.run(getOperation());
Operation *op = getOperation();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<MetadataAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();

// If an op has more than one result, then the lattice is the same for each
// result, and we just print one of the results.
op->walk([&](Operation *op) {
if (op->getNumResults()) {
Value result = op->getResult(0);
auto lattice = solver.lookupState<MetadataLatticeValueLattice>(result);
llvm::outs() << OpWithFlags(op, OpPrintingFlags().skipRegions()) << " : ";
lattice->print(llvm::outs());
}
});
...
}
```

void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
LatticeElement<MetadataLatticeValue> *latticeElement = analysis.lookupLatticeElement(value);

// If we don't have an element, the `value` wasn't visited during our analysis
// meaning that it could be dead. We need to treat this conservatively.
if (!lattice)
return;
The following is a simple example. More tests can be found in the `mlir/Example/dataflow`.

// Our lattice element has a value, use it:
MetadataLatticeValue &value = lattice->getValue();
...
```mlir
func.func @single_join(%arg0 : index, %arg1 : index) -> index {
%1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index
%2 = arith.addi %1, %arg1 : index
return %2 : index
}
```

The above IR will print the following after running pass.

```
%0 = arith.addi %arg0, %arg1 {metadata = {likes_pizza = true}} : index : {"likes_pizza": true}
%1 = arith.addi %0, %arg1 : index : {"likes_pizza": true}
```
1 change: 1 addition & 0 deletions mlir/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(dataflow)
add_subdirectory(toy)
add_subdirectory(transform)
add_subdirectory(transform-opt)
Expand Down
7 changes: 7 additions & 0 deletions mlir/examples/dataflow/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_custom_target(DataFlowExample)
set_target_properties(DataFlowExample PROPERTIES FOLDER "MLIR/Examples")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)

add_subdirectory(lib)
add_subdirectory(dataflow-opt)
11 changes: 11 additions & 0 deletions mlir/examples/dataflow/dataflow-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_dependencies(DataFlowExample dataflow-opt)
add_llvm_example(dataflow-opt
dataflow-opt.cpp
)

target_link_libraries(dataflow-opt
PRIVATE
MLIRIR
MLIRMlirOptMain
MLIRTestMetadataAnalysisPass
)
Loading