Skip to content

Commit 3c82411

Browse files
committed
Make getDominatingDataClauses a public utility
1 parent dd1b4ab commit 3c82411

File tree

3 files changed

+86
-62
lines changed

3 files changed

+86
-62
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@
1010
#define MLIR_DIALECT_OPENACC_OPENACCUTILS_H_
1111

1212
#include "mlir/Dialect/OpenACC/OpenACC.h"
13+
#include "llvm/ADT/SmallVector.h"
1314

1415
namespace mlir {
16+
class DominanceInfo;
17+
class PostDominanceInfo;
18+
class Value;
19+
class Operation;
1520
namespace acc {
1621

1722
/// Used to obtain the enclosing compute construct operation that contains
@@ -62,6 +67,22 @@ mlir::Value getBaseEntity(mlir::Value val);
6267
bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
6368
mlir::Operation **definingOpPtr = nullptr);
6469

70+
/// Collects all data clauses that dominate the compute construct.
71+
/// This includes data clauses from:
72+
/// - The compute construct itself
73+
/// - Enclosing data constructs
74+
/// - Applicable declare directives (those that dominate and post-dominate)
75+
/// This is used to determine if a variable is already covered by an existing
76+
/// data clause.
77+
/// \param computeConstructOp The compute construct operation
78+
/// \param domInfo Dominance information
79+
/// \param postDomInfo Post-dominance information
80+
/// \return Vector of data clause values that dominate the compute construct
81+
llvm::SmallVector<mlir::Value>
82+
getDominatingDataClauses(mlir::Operation *computeConstructOp,
83+
mlir::DominanceInfo &domInfo,
84+
mlir::PostDominanceInfo &postDomInfo);
85+
6586
} // namespace acc
6687
} // namespace mlir
6788

mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,6 @@ class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
237237
void runOnOperation() override;
238238

239239
private:
240-
/// Collects all data clauses that dominate the compute construct.
241-
/// Needed to determine if a variable is already covered by an existing data
242-
/// clause.
243-
SmallVector<Value> getDominatingDataClauses(Operation *computeConstructOp);
244-
245240
/// Looks through the `dominatingDataClauses` to find the original data clause
246241
/// op for an alias. Returns nullptr if no original data clause op is found.
247242
template <typename OpT>
@@ -300,62 +295,6 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion) {
300295
return true;
301296
}
302297

303-
SmallVector<Value>
304-
ACCImplicitData::getDominatingDataClauses(Operation *computeConstructOp) {
305-
llvm::SmallSetVector<Value, 8> dominatingDataClauses;
306-
307-
llvm::TypeSwitch<Operation *>(computeConstructOp)
308-
.Case<acc::ParallelOp, acc::KernelsOp, acc::SerialOp>([&](auto op) {
309-
for (auto dataClause : op.getDataClauseOperands()) {
310-
dominatingDataClauses.insert(dataClause);
311-
}
312-
})
313-
.Default([](Operation *) {});
314-
315-
// Collect the data clauses from enclosing data constructs.
316-
Operation *currParentOp = computeConstructOp->getParentOp();
317-
while (currParentOp) {
318-
if (isa<acc::DataOp>(currParentOp)) {
319-
for (auto dataClause :
320-
dyn_cast<acc::DataOp>(currParentOp).getDataClauseOperands()) {
321-
dominatingDataClauses.insert(dataClause);
322-
}
323-
}
324-
currParentOp = currParentOp->getParentOp();
325-
}
326-
327-
// Find the enclosing function/subroutine
328-
auto funcOp = computeConstructOp->getParentOfType<FunctionOpInterface>();
329-
if (!funcOp)
330-
return dominatingDataClauses.takeVector();
331-
332-
// Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that
333-
// dominate and post-dominate the compute construct and add their data
334-
// clauses to the list.
335-
auto &domInfo = this->getAnalysis<DominanceInfo>();
336-
auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
337-
funcOp->walk([&](acc::DeclareEnterOp declareEnterOp) {
338-
if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
339-
// Collect all `acc.declare_exit` ops for this token.
340-
SmallVector<acc::DeclareExitOp> exits;
341-
for (auto *user : declareEnterOp.getToken().getUsers())
342-
if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
343-
exits.push_back(declareExit);
344-
345-
// Only add clauses if every `acc.declare_exit` op post-dominates the
346-
// compute construct.
347-
if (!exits.empty() && llvm::all_of(exits, [&](acc::DeclareExitOp exitOp) {
348-
return postDomInfo.postDominates(exitOp, computeConstructOp);
349-
})) {
350-
for (auto dataClause : declareEnterOp.getDataClauseOperands())
351-
dominatingDataClauses.insert(dataClause);
352-
}
353-
}
354-
});
355-
356-
return dominatingDataClauses.takeVector();
357-
}
358-
359298
template <typename OpT>
360299
Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
361300
Value var, OpBuilder &builder, OpT computeConstructOp,
@@ -775,7 +714,10 @@ void ACCImplicitData::generateImplicitDataOps(
775714
LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n"
776715
<< computeConstructOp << "\n");
777716
}
778-
auto dominatingDataClauses = getDominatingDataClauses(computeConstructOp);
717+
auto &domInfo = this->getAnalysis<DominanceInfo>();
718+
auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
719+
auto dominatingDataClauses =
720+
acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo);
779721
for (auto var : candidateVars) {
780722
auto newDataClauseOp = generateDataClauseOpForCandidate(
781723
var, module, builder, computeConstructOp, dominatingDataClauses,

mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
1010

1111
#include "mlir/Dialect/OpenACC/OpenACC.h"
12+
#include "mlir/IR/Dominance.h"
1213
#include "mlir/IR/SymbolTable.h"
1314
#include "mlir/Interfaces/FunctionInterfaces.h"
1415
#include "mlir/Interfaces/ViewLikeInterface.h"
16+
#include "llvm/ADT/SetVector.h"
1517
#include "llvm/ADT/TypeSwitch.h"
1618
#include "llvm/IR/Intrinsics.h"
1719
#include "llvm/Support/Casting.h"
@@ -205,3 +207,62 @@ bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
205207
bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName());
206208
return hasDeclare;
207209
}
210+
211+
llvm::SmallVector<mlir::Value>
212+
mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp,
213+
mlir::DominanceInfo &domInfo,
214+
mlir::PostDominanceInfo &postDomInfo) {
215+
llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses;
216+
217+
llvm::TypeSwitch<mlir::Operation *>(computeConstructOp)
218+
.Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>(
219+
[&](auto op) {
220+
for (auto dataClause : op.getDataClauseOperands()) {
221+
dominatingDataClauses.insert(dataClause);
222+
}
223+
})
224+
.Default([](mlir::Operation *) {});
225+
226+
// Collect the data clauses from enclosing data constructs.
227+
mlir::Operation *currParentOp = computeConstructOp->getParentOp();
228+
while (currParentOp) {
229+
if (mlir::isa<mlir::acc::DataOp>(currParentOp)) {
230+
for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp)
231+
.getDataClauseOperands()) {
232+
dominatingDataClauses.insert(dataClause);
233+
}
234+
}
235+
currParentOp = currParentOp->getParentOp();
236+
}
237+
238+
// Find the enclosing function/subroutine
239+
auto funcOp =
240+
computeConstructOp->getParentOfType<mlir::FunctionOpInterface>();
241+
if (!funcOp)
242+
return dominatingDataClauses.takeVector();
243+
244+
// Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that
245+
// dominate and post-dominate the compute construct and add their data
246+
// clauses to the list.
247+
funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) {
248+
if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
249+
// Collect all `acc.declare_exit` ops for this token.
250+
llvm::SmallVector<mlir::acc::DeclareExitOp> exits;
251+
for (auto *user : declareEnterOp.getToken().getUsers())
252+
if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user))
253+
exits.push_back(declareExit);
254+
255+
// Only add clauses if every `acc.declare_exit` op post-dominates the
256+
// compute construct.
257+
if (!exits.empty() &&
258+
llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) {
259+
return postDomInfo.postDominates(exitOp, computeConstructOp);
260+
})) {
261+
for (auto dataClause : declareEnterOp.getDataClauseOperands())
262+
dominatingDataClauses.insert(dataClause);
263+
}
264+
}
265+
});
266+
267+
return dominatingDataClauses.takeVector();
268+
}

0 commit comments

Comments
 (0)