Skip to content

Commit 03b89b1

Browse files
committed
add support for enter data
1 parent 4c458ea commit 03b89b1

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@
5858
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
5959
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
6060
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
61-
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
62-
mlir::acc::DeclareEnterOp
61+
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
6362
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
6463
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
6564
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp

mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,26 +71,43 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7171
}
7272
}
7373

74-
// Helper function to process declare enter/exit pairs
75-
static void processDeclareEnterExit(
76-
acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
74+
template <typename Op>
75+
static void replaceAllUsesInUnstructuredComputeRegionWith(
76+
Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
7777
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78-
// For declare enter/exit pairs, verify there is exactly one exit op using the
79-
// token
80-
if (!op.getToken().hasOneUse())
81-
op.emitError("declare enter token must have exactly one use");
82-
Operation *user = *op.getToken().getUsers().begin();
83-
auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
84-
if (!declareExit)
85-
op.emitError("declare enter token must be used by declare exit op");
78+
79+
Operation *exitOp = op.getOperation();
80+
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81+
// For declare enter/exit pairs, verify there is exactly one exit op using
82+
// the token
83+
if (!op.getToken().hasOneUse())
84+
op.emitError("declare enter token must have exactly one use");
85+
Operation *user = *op.getToken().getUsers().begin();
86+
auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
87+
if (!declareExit)
88+
op.emitError("declare enter token must be used by declare exit op");
89+
exitOp = declareExit;
90+
} else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
91+
// For enter/exit data pairs, find the corresponding exit_data op
92+
Operation *nextOp = op.getOperation()->getNextNode();
93+
while (nextOp && !isa<acc::ExitDataOp>(nextOp))
94+
nextOp = nextOp->getNextNode();
95+
if (!nextOp)
96+
op.emitError("enter data must have a corresponding exit data op");
97+
exitOp = nextOp;
98+
}
8699

87100
for (auto p : values) {
88101
Value hostVal = std::get<0>(p);
89102
Value deviceVal = std::get<1>(p);
90103
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
91104
Operation *owner = use.getOwner();
105+
// Check that:
106+
// It's the case that the acc entry operation dominates the use.
107+
// It's the case that one of the acc exit operations consuming the token
108+
// post-dominates the use
92109
if (!domInfo.dominates(op.getOperation(), owner) ||
93-
!postDomInfo.postDominates(declareExit.getOperation(), owner))
110+
!postDomInfo.postDominates(exitOp, owner))
94111
continue;
95112
if (insideAccComputeRegion(owner))
96113
use.set(deviceVal);
@@ -114,17 +131,20 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
114131
!std::is_same_v<Op, acc::DataOp> &&
115132
!std::is_same_v<Op, acc::DeclareOp> &&
116133
!std::is_same_v<Op, acc::HostDataOp> &&
117-
!std::is_same_v<Op, acc::DeclareEnterOp>) {
134+
!std::is_same_v<Op, acc::DeclareEnterOp> &&
135+
!std::is_same_v<Op, acc::EnterDataOp>) {
118136
collectVars(op.getReductionOperands(), values, hostToDevice);
119137
collectVars(op.getPrivateOperands(), values, hostToDevice);
120138
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
121139
}
122140
}
123141

124-
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
142+
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
143+
std::is_same_v<Op, acc::EnterDataOp>) {
125144
assert(domInfo && postDomInfo &&
126145
"Dominance info required for DeclareEnterOp");
127-
processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
146+
replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
147+
*postDomInfo);
128148
} else {
129149
for (auto p : values) {
130150
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
@@ -151,7 +171,8 @@ class LegalizeDataValuesInRegion
151171
funcOp.walk([&](Operation *op) {
152172
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
153173
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
154-
applyToAccDataConstruct))
174+
applyToAccDataConstruct) &&
175+
!isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
155176
return;
156177

157178
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -171,6 +192,9 @@ class LegalizeDataValuesInRegion
171192
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
172193
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
173194
&postDomInfo);
195+
} else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
196+
collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
197+
&postDomInfo);
174198
} else {
175199
llvm_unreachable("unsupported acc region op");
176200
}

0 commit comments

Comments
 (0)