Skip to content

Commit e00eadf

Browse files
pifon2amemfrob
authored andcommitted
[mlir][nfc] Move getInnermostParallelLoops to SCF/Transforms/Utils.h.
1 parent 5346b95 commit e00eadf

File tree

3 files changed

+31
-24
lines changed

3 files changed

+31
-24
lines changed

mlir/include/mlir/Dialect/SCF/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
namespace mlir {
1919
class FuncOp;
20+
class Operation;
2021
class OpBuilder;
2122
class ValueRange;
2223

@@ -57,5 +58,11 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
5758
/// region is inlined into a new FuncOp that is captured by the pointer.
5859
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
5960
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
61+
62+
/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
63+
/// loops are those that do not contain further parallel loops themselves.
64+
bool getInnermostParallelLoops(Operation *rootOp,
65+
SmallVectorImpl<scf::ParallelOp> &result);
66+
6067
} // end namespace mlir
6168
#endif // MLIR_DIALECT_SCF_UTILS_H_

mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/SCF/Passes.h"
1616
#include "mlir/Dialect/SCF/SCF.h"
1717
#include "mlir/Dialect/SCF/Transforms.h"
18+
#include "mlir/Dialect/SCF/Utils.h"
1819
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1920

2021
using namespace mlir;
@@ -126,29 +127,6 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
126127
op.erase();
127128
}
128129

129-
/// Get a list of most nested parallel loops.
130-
static bool getInnermostPloops(Operation *rootOp,
131-
SmallVectorImpl<ParallelOp> &result) {
132-
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
133-
bool rootEnclosesPloops = false;
134-
for (Region &region : rootOp->getRegions()) {
135-
for (Block &block : region.getBlocks()) {
136-
for (Operation &op : block) {
137-
bool enclosesPloops = getInnermostPloops(&op, result);
138-
rootEnclosesPloops |= enclosesPloops;
139-
if (auto ploop = dyn_cast<ParallelOp>(op)) {
140-
rootEnclosesPloops = true;
141-
142-
// Collect ploop if it is an innermost one.
143-
if (!enclosesPloops)
144-
result.push_back(ploop);
145-
}
146-
}
147-
}
148-
}
149-
return rootEnclosesPloops;
150-
}
151-
152130
namespace {
153131
struct ParallelLoopTiling
154132
: public SCFParallelLoopTilingBase<ParallelLoopTiling> {
@@ -159,7 +137,7 @@ struct ParallelLoopTiling
159137

160138
void runOnFunction() override {
161139
SmallVector<ParallelOp, 2> innermostPloops;
162-
getInnermostPloops(getFunction().getOperation(), innermostPloops);
140+
getInnermostParallelLoops(getFunction().getOperation(), innermostPloops);
163141
for (ParallelOp ploop : innermostPloops) {
164142
// FIXME: Add reduction support.
165143
if (ploop.getNumReductions() == 0)

mlir/lib/Dialect/SCF/Transforms/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,25 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
123123
if (elseFn && !ifOp.elseRegion().empty())
124124
*elseFn = outline(ifOp.elseRegion(), elseFnName);
125125
}
126+
127+
bool mlir::getInnermostParallelLoops(Operation *rootOp,
128+
SmallVectorImpl<scf::ParallelOp> &result) {
129+
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
130+
bool rootEnclosesPloops = false;
131+
for (Region &region : rootOp->getRegions()) {
132+
for (Block &block : region.getBlocks()) {
133+
for (Operation &op : block) {
134+
bool enclosesPloops = getInnermostParallelLoops(&op, result);
135+
rootEnclosesPloops |= enclosesPloops;
136+
if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
137+
rootEnclosesPloops = true;
138+
139+
// Collect parallel loop if it is an innermost one.
140+
if (!enclosesPloops)
141+
result.push_back(ploop);
142+
}
143+
}
144+
}
145+
}
146+
return rootEnclosesPloops;
147+
}

0 commit comments

Comments
 (0)