Skip to content

Commit 4c458ea

Browse files
committed
add support for acc.declare_enter
1 parent b02f2e8 commit 4c458ea

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
5959
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
6060
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
61-
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
61+
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
62+
mlir::acc::DeclareEnterOp
6263
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
6364
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
6465
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp

mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
1212
#include "mlir/Dialect/OpenACC/OpenACC.h"
13+
#include "mlir/IR/Dominance.h"
1314
#include "mlir/Pass/Pass.h"
1415
#include "mlir/Transforms/RegionUtils.h"
1516
#include "llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7071
}
7172
}
7273

74+
// Helper function to process declare enter/exit pairs
75+
static void processDeclareEnterExit(
76+
acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
77+
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78+
// For declare enter/exit pairs, verify there is exactly one exit op using the
79+
// token
80+
if (!op.getToken().hasOneUse())
81+
op.emitError("declare enter token must have exactly one use");
82+
Operation *user = *op.getToken().getUsers().begin();
83+
auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
84+
if (!declareExit)
85+
op.emitError("declare enter token must be used by declare exit op");
86+
87+
for (auto p : values) {
88+
Value hostVal = std::get<0>(p);
89+
Value deviceVal = std::get<1>(p);
90+
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
91+
Operation *owner = use.getOwner();
92+
if (!domInfo.dominates(op.getOperation(), owner) ||
93+
!postDomInfo.postDominates(declareExit.getOperation(), owner))
94+
continue;
95+
if (insideAccComputeRegion(owner))
96+
use.set(deviceVal);
97+
}
98+
}
99+
}
100+
73101
template <typename Op>
74-
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
102+
static void
103+
collectAndReplaceInRegion(Op &op, bool hostToDevice,
104+
DominanceInfo *domInfo = nullptr,
105+
PostDominanceInfo *postDomInfo = nullptr) {
75106
llvm::SmallVector<std::pair<Value, Value>> values;
76107

77108
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
82113
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83114
!std::is_same_v<Op, acc::DataOp> &&
84115
!std::is_same_v<Op, acc::DeclareOp> &&
85-
!std::is_same_v<Op, acc::HostDataOp>) {
116+
!std::is_same_v<Op, acc::HostDataOp> &&
117+
!std::is_same_v<Op, acc::DeclareEnterOp>) {
86118
collectVars(op.getReductionOperands(), values, hostToDevice);
87119
collectVars(op.getPrivateOperands(), values, hostToDevice);
88120
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
89121
}
90122
}
91123

92-
for (auto p : values)
93-
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
94-
op.getRegion());
124+
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
125+
assert(domInfo && postDomInfo &&
126+
"Dominance info required for DeclareEnterOp");
127+
processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
128+
} else {
129+
for (auto p : values) {
130+
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
131+
op.getRegion());
132+
}
133+
}
95134
}
96135

97136
class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
105144
func::FuncOp funcOp = getOperation();
106145
bool replaceHostVsDevice = this->hostToDevice.getValue();
107146

147+
// Get dominance info for the function
148+
DominanceInfo domInfo(funcOp);
149+
PostDominanceInfo postDomInfo(funcOp);
150+
108151
funcOp.walk([&](Operation *op) {
109152
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110153
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
125168
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
126169
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127170
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
171+
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
172+
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
173+
&postDomInfo);
128174
} else {
129175
llvm_unreachable("unsupported acc region op");
130176
}

0 commit comments

Comments
 (0)