Skip to content

Commit a1de011

Browse files
committed
SparseAnalysis: support ReturnLike terminators
This PR adds support in sparse analysis for non-control flow region-bearing ops that have return-like terminators. By default it propagates the terminator's operand lattices to the containing op's result lattices, and also allows the analysis subclass to override this behavior.
1 parent 5aa3171 commit a1de011

File tree

7 files changed

+263
-20
lines changed

7 files changed

+263
-20
lines changed

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
220220
Operation *op, const RegionSuccessor &successor,
221221
ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
222222

223+
/// Visit a region terminator. This is intended for non-control-flow
224+
/// region-bearing ops whose terminators determine the lattice values of the
225+
/// parent op's results.
226+
virtual LogicalResult visitNonControlFlowTerminatorImpl(
227+
Operation *terminatorOp,
228+
ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
229+
ArrayRef<AbstractSparseLattice *> parentResultLattices) = 0;
230+
223231
/// Get the lattice element of a value.
224232
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
225233

@@ -235,6 +243,29 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
235243
/// Join the lattice element and propagate and update if it changed.
236244
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
237245

246+
// Get the lattice elements of the operands.
247+
SmallVector<const AbstractSparseLattice *> getOperandLattices(Operation *op) {
248+
SmallVector<const AbstractSparseLattice *> operandLattices;
249+
operandLattices.reserve(op->getNumOperands());
250+
for (Value operand : op->getOperands()) {
251+
AbstractSparseLattice *operandLattice = getLatticeElement(operand);
252+
operandLattice->useDefSubscribe(this);
253+
operandLattices.push_back(operandLattice);
254+
}
255+
return operandLattices;
256+
}
257+
258+
// Get the lattice elements of the results.
259+
SmallVector<AbstractSparseLattice *> getResultLattices(Operation *op) {
260+
SmallVector<AbstractSparseLattice *> resultLattices;
261+
resultLattices.reserve(op->getNumResults());
262+
for (Value result : op->getResults()) {
263+
AbstractSparseLattice *resultLattice = getLatticeElement(result);
264+
resultLattices.push_back(resultLattice);
265+
}
266+
return resultLattices;
267+
}
268+
238269
private:
239270
/// Recursively initialize the analysis on nested operations and blocks.
240271
LogicalResult initializeRecursively(Operation *op);
@@ -299,6 +330,28 @@ class SparseForwardDataFlowAnalysis
299330
setAllToEntryStates(resultLattices);
300331
}
301332

333+
/// Visit a region terminator. This is intended for non-control-flow
334+
/// region-bearing ops whose terminators determine the lattice values of the
335+
/// parent op's results. By default the terminator's operand lattices are
336+
/// forwarded to the parent result lattices, if there is a 1-1
337+
/// correspondence.
338+
virtual LogicalResult visitNonControlFlowTerminator(
339+
Operation *terminatorOp,
340+
ArrayRef<const StateT *> terminatorOperandLattices,
341+
ArrayRef<StateT *> parentResultLattices) {
342+
// ReturnLike terminators forward their lattice values to the results of the
343+
// parent op.
344+
if (terminatorOp->hasTrait<OpTrait::ReturnLike>() &&
345+
terminatorOperandLattices.size() == parentResultLattices.size()) {
346+
for (const auto &[operandLattice, resultLattice] :
347+
llvm::zip(terminatorOperandLattices, parentResultLattices)) {
348+
propagateIfChanged(resultLattice, resultLattice->join(*operandLattice));
349+
}
350+
}
351+
352+
return success();
353+
}
354+
302355
/// Given an operation with possible region control-flow, the lattices of the
303356
/// operands, and a region successor, compute the lattice values for block
304357
/// arguments that are not accounted for by the branching control flow (ex.
@@ -370,6 +423,18 @@ class SparseForwardDataFlowAnalysis
370423
argLattices.size()},
371424
firstIndex);
372425
}
426+
LogicalResult visitNonControlFlowTerminatorImpl(
427+
Operation *terminatorOp,
428+
ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
429+
ArrayRef<AbstractSparseLattice *> parentResultLattices) override {
430+
return visitNonControlFlowTerminator(
431+
terminatorOp,
432+
{reinterpret_cast<const StateT *const *>(
433+
terminatorOperandLattices.begin()),
434+
terminatorOperandLattices.size()},
435+
{reinterpret_cast<StateT *const *>(parentResultLattices.begin()),
436+
parentResultLattices.size()});
437+
}
373438
void setToEntryState(AbstractSparseLattice *lattice) override {
374439
return setToEntryState(reinterpret_cast<StateT *>(lattice));
375440
}

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10+
11+
#include <cassert>
12+
#include <optional>
13+
1014
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1115
#include "mlir/Analysis/DataFlowFramework.h"
1216
#include "mlir/IR/Attributes.h"
@@ -20,8 +24,6 @@
2024
#include "mlir/Support/LLVM.h"
2125
#include "llvm/ADT/STLExtras.h"
2226
#include "llvm/Support/Casting.h"
23-
#include <cassert>
24-
#include <optional>
2527

2628
using namespace mlir;
2729
using namespace mlir::dataflow;
@@ -94,23 +96,32 @@ AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
9496

9597
LogicalResult
9698
AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
97-
// Exit early on operations with no results.
98-
if (op->getNumResults() == 0)
99-
return success();
100-
10199
// If the containing block is not executable, bail out.
102100
if (op->getBlock() != nullptr &&
103101
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
104102
return success();
105103

106-
// Get the result lattices.
107-
SmallVector<AbstractSparseLattice *> resultLattices;
108-
resultLattices.reserve(op->getNumResults());
109-
for (Value result : op->getResults()) {
110-
AbstractSparseLattice *resultLattice = getLatticeElement(result);
111-
resultLattices.push_back(resultLattice);
104+
// Region terminators which are not part of control flow have a special
105+
// transfer function.
106+
if (op->hasTrait<OpTrait::IsTerminator>()) {
107+
Operation *parentOp = op->getParentOp();
108+
if (parentOp && !isa<RegionBranchOpInterface>(parentOp) &&
109+
!isa<RegionBranchTerminatorOpInterface>(op) &&
110+
parentOp->getNumResults() > 0) {
111+
SmallVector<const AbstractSparseLattice *> operandLattices =
112+
getOperandLattices(op);
113+
SmallVector<AbstractSparseLattice *> parentResultLattices =
114+
getResultLattices(parentOp);
115+
return visitNonControlFlowTerminatorImpl(op, operandLattices,
116+
parentResultLattices);
117+
}
112118
}
113119

120+
if (op->getNumResults() == 0)
121+
return success();
122+
123+
SmallVector<AbstractSparseLattice *> resultLattices = getResultLattices(op);
124+
114125
// The results of a region branch operation are determined by control-flow.
115126
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
116127
visitRegionSuccessors(getProgramPointAfter(branch), branch,
@@ -119,14 +130,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
119130
return success();
120131
}
121132

122-
// Grab the lattice elements of the operands.
123-
SmallVector<const AbstractSparseLattice *> operandLattices;
124-
operandLattices.reserve(op->getNumOperands());
125-
for (Value operand : op->getOperands()) {
126-
AbstractSparseLattice *operandLattice = getLatticeElement(operand);
127-
operandLattice->useDefSubscribe(this);
128-
operandLattices.push_back(operandLattice);
129-
}
133+
SmallVector<const AbstractSparseLattice *> operandLattices =
134+
getOperandLattices(op);
130135

131136
if (auto call = dyn_cast<CallOpInterface>(op)) {
132137
// If the call operation is to an external function, attempt to infer the
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt --test-integer-lattice %s | FileCheck %s
2+
3+
// CHECK-LABEL: @test_returnlike
4+
// CHECK: analysis_return_like_region_op
5+
// CHECK-NEXT: arith.constant {test.operand_lattices = [], test.result_lattices = [0 : index]} 1 : i32
6+
// CHECK-NEXT: region_yield
7+
// CHECK-SAME: {test.operand_lattices = [0 : index], test.result_lattices = []}
8+
9+
// The core of the return-like test: the operand lattices of the yield forward
10+
// to the result lattices of the enclosing region-holding op
11+
12+
// CHECK-NEXT: }) {test.operand_lattices = [], test.result_lattices = [0 : index]} : () -> i32
13+
func.func @test_returnlike() {
14+
%0 = "test.analysis_return_like_region_op"() ({
15+
%0 = arith.constant 1 : i32
16+
"test.region_yield" (%0) : (i32) -> ()
17+
}) : () -> i32
18+
return
19+
}

mlir/test/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis
1717
DataFlow/TestDenseForwardDataFlowAnalysis.cpp
1818
DataFlow/TestLivenessAnalysis.cpp
1919
DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
20+
DataFlow/TestSparseForwardDataFlowAnalysis.cpp
2021

2122
EXCLUDE_FROM_LIBMLIR
2223

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//===- TestForwardDataFlowAnalysis.cpp - Test dead code analysis ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10+
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Interfaces/SideEffectInterfaces.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::dataflow;
18+
19+
namespace {
20+
21+
class IntegerState {
22+
public:
23+
IntegerState() : value(0) {}
24+
explicit IntegerState(int value) : value(value) {}
25+
~IntegerState() = default;
26+
27+
int get() const { return value; }
28+
29+
bool operator==(const IntegerState &rhs) const { return value == rhs.value; }
30+
31+
static IntegerState join(const IntegerState &lhs, const IntegerState &rhs) {
32+
return IntegerState{std::max(lhs.get(), rhs.get())};
33+
}
34+
35+
void print(llvm::raw_ostream &os) const {
36+
os << "IntegerState(" << value << ")";
37+
}
38+
39+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
40+
const IntegerState &state) {
41+
state.print(os);
42+
return os;
43+
}
44+
45+
private:
46+
int value;
47+
};
48+
49+
/// This lattice represents, for a given value, the set of memory resources that
50+
/// this value, or anything derived from this value, is potentially written to.
51+
struct IntegerLattice : public Lattice<IntegerState> {
52+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IntegerLattice)
53+
using Lattice::Lattice;
54+
};
55+
56+
/// An analysis that, by going backwards along the dataflow graph, annotates
57+
/// each value with all the memory resources it (or anything derived from it)
58+
/// is eventually written to.
59+
class IntegerLatticeAnalysis
60+
: public SparseForwardDataFlowAnalysis<IntegerLattice> {
61+
public:
62+
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
63+
64+
LogicalResult visitOperation(Operation *op,
65+
ArrayRef<const IntegerLattice *> operands,
66+
ArrayRef<IntegerLattice *> results) override;
67+
68+
void setToEntryState(IntegerLattice *lattice) override {
69+
propagateIfChanged(lattice, lattice->join(IntegerState()));
70+
}
71+
};
72+
73+
LogicalResult IntegerLatticeAnalysis::visitOperation(
74+
Operation *op, ArrayRef<const IntegerLattice *> operands,
75+
ArrayRef<IntegerLattice *> results) {
76+
for (auto *operand : operands) {
77+
for (auto *result : results) {
78+
propagateIfChanged(result, result->join(*operand));
79+
}
80+
}
81+
return success();
82+
}
83+
84+
} // end anonymous namespace
85+
86+
namespace {
87+
struct TestIntegerLatticePass
88+
: public PassWrapper<TestIntegerLatticePass, OperationPass<>> {
89+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntegerLatticePass)
90+
91+
TestIntegerLatticePass() = default;
92+
TestIntegerLatticePass(const TestIntegerLatticePass &other)
93+
: PassWrapper(other) {}
94+
95+
StringRef getArgument() const override { return "test-integer-lattice"; }
96+
97+
void runOnOperation() override {
98+
Operation *op = getOperation();
99+
MLIRContext *ctx = &getContext();
100+
101+
DataFlowSolver solver;
102+
solver.load<DeadCodeAnalysis>();
103+
solver.load<SparseConstantPropagation>();
104+
solver.load<IntegerLatticeAnalysis>();
105+
if (failed(solver.initializeAndRun(op)))
106+
return signalPassFailure();
107+
108+
// Walk the IR and attach operand and result lattices as attributes to each
109+
// operation.
110+
op->walk([&](Operation *op) {
111+
SmallVector<Attribute> operandAttrs;
112+
SmallVector<Attribute> resultAttrs;
113+
for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
114+
const IntegerLattice *lattice =
115+
solver.lookupState<IntegerLattice>(operand);
116+
assert(lattice && "expected a sparse lattice");
117+
operandAttrs.push_back(
118+
IntegerAttr::get(IndexType::get(ctx), lattice->getValue().get()));
119+
}
120+
for (auto [index, result] : llvm::enumerate(op->getResults())) {
121+
const IntegerLattice *lattice =
122+
solver.lookupState<IntegerLattice>(result);
123+
assert(lattice && "expected a sparse lattice");
124+
resultAttrs.push_back(
125+
IntegerAttr::get(IndexType::get(ctx), lattice->getValue().get()));
126+
}
127+
128+
op->setAttr("test.operand_lattices", ArrayAttr::get(ctx, operandAttrs));
129+
op->setAttr("test.result_lattices", ArrayAttr::get(ctx, resultAttrs));
130+
});
131+
}
132+
};
133+
} // end anonymous namespace
134+
135+
namespace mlir {
136+
namespace test {
137+
void registerTestIntegerLatticePass() {
138+
PassRegistration<TestIntegerLatticePass>();
139+
}
140+
} // end namespace test
141+
} // end namespace mlir

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,4 +3507,14 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
35073507
}];
35083508
}
35093509

3510+
// ==------------------------------------------------------------------------===//
3511+
// Test Analysis ReturnLike
3512+
//===----------------------------------------------------------------------===//
3513+
def AnalysisReturnLikeRegionOp : TEST_Op<"analysis_return_like_region_op",
3514+
[SingleBlockImplicitTerminator<"RegionYieldOp">]> {
3515+
let regions = (region AnyRegion:$region);
3516+
let results = (outs AnyType:$result);
3517+
}
3518+
3519+
35103520
#endif // TEST_OPS

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ void registerTestComposeSubView();
105105
void registerTestMultiBuffering();
106106
void registerTestIRVisitorsPass();
107107
void registerTestGenericIRVisitorsPass();
108+
void registerTestIntegerLatticePass();
108109
void registerTestInterfaces();
109110
void registerTestIRVisitorsPass();
110111
void registerTestLastModifiedPass();
@@ -249,6 +250,7 @@ void registerTestPasses() {
249250
mlir::test::registerTestMultiBuffering();
250251
mlir::test::registerTestIRVisitorsPass();
251252
mlir::test::registerTestGenericIRVisitorsPass();
253+
mlir::test::registerTestIntegerLatticePass();
252254
mlir::test::registerTestInterfaces();
253255
mlir::test::registerTestIRVisitorsPass();
254256
mlir::test::registerTestLastModifiedPass();

0 commit comments

Comments
 (0)