Skip to content

Commit 4814f8f

Browse files
committed
change to support multiple exits
1 parent 823d31d commit 4814f8f

File tree

1 file changed

+27
-34
lines changed

1 file changed

+27
-34
lines changed

mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,41 @@ static void replaceAllUsesInUnstructuredComputeRegionWith(
7676
Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
7777
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
7878

79-
Operation *exitOp = op.getOperation();
79+
SmallVector<Operation *> exitOps;
8080
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81-
// For declare enter/exit pairs, verify there is exactly one exit op using
82-
// the token
83-
if (!op.getToken().hasOneUse())
84-
op.emitError("declare enter token must have exactly one use");
85-
Operation *user = *op.getToken().getUsers().begin();
86-
auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
87-
if (!declareExit)
88-
op.emitError("declare enter token must be used by declare exit op");
89-
exitOp = declareExit;
90-
} else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
91-
// For enter/exit data pairs, find the corresponding exit_data op
92-
Operation *nextOp = op.getOperation()->getNextNode();
93-
while (nextOp && !isa<acc::ExitDataOp>(nextOp))
94-
nextOp = nextOp->getNextNode();
95-
if (!nextOp)
96-
op.emitError("enter data must have a corresponding exit data op");
97-
exitOp = nextOp;
81+
// For declare enter/exit pairs, collect all exit ops
82+
for (auto *user : op.getToken().getUsers()) {
83+
if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84+
exitOps.push_back(declareExit);
85+
}
86+
if (exitOps.empty())
87+
op.emitError(
88+
"declare enter token must be used by at least one declare exit op");
9889
}
9990

10091
for (auto p : values) {
10192
Value hostVal = std::get<0>(p);
10293
Value deviceVal = std::get<1>(p);
10394
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
10495
Operation *owner = use.getOwner();
105-
// Check that:
106-
// It's the case that the acc entry operation dominates the use.
107-
// It's the case that one of the acc exit operations consuming the token
96+
97+
// Check It's the case that the acc entry operation dominates the use.
98+
if (!domInfo.dominates(op.getOperation(), owner))
99+
continue;
100+
101+
// Check It's the case that at least one of the acc exit operations
108102
// post-dominates the use
109-
if (!domInfo.dominates(op.getOperation(), owner) ||
110-
!postDomInfo.postDominates(exitOp, owner))
103+
bool hasPostDominatingExit = false;
104+
for (auto *exit : exitOps) {
105+
if (postDomInfo.postDominates(exit, owner)) {
106+
hasPostDominatingExit = true;
107+
break;
108+
}
109+
}
110+
111+
if (!hasPostDominatingExit)
111112
continue;
113+
112114
if (insideAccComputeRegion(owner))
113115
use.set(deviceVal);
114116
}
@@ -131,8 +133,7 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
131133
!std::is_same_v<Op, acc::DataOp> &&
132134
!std::is_same_v<Op, acc::DeclareOp> &&
133135
!std::is_same_v<Op, acc::HostDataOp> &&
134-
!std::is_same_v<Op, acc::DeclareEnterOp> &&
135-
!std::is_same_v<Op, acc::EnterDataOp>) {
136+
!std::is_same_v<Op, acc::DeclareEnterOp>) {
136137
collectVars(op.getReductionOperands(), values, hostToDevice);
137138
collectVars(op.getPrivateOperands(), values, hostToDevice);
138139
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
@@ -173,7 +174,7 @@ class LegalizeDataValuesInRegion
173174
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
174175
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
175176
applyToAccDataConstruct) &&
176-
!isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
177+
!isa<acc::DeclareEnterOp>(*op))
177178
return;
178179

179180
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -198,14 +199,6 @@ class LegalizeDataValuesInRegion
198199
}
199200
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
200201
&postDomInfo);
201-
} else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
202-
if (!computedDomInfo) {
203-
domInfo = DominanceInfo(funcOp);
204-
postDomInfo = PostDominanceInfo(funcOp);
205-
computedDomInfo = true;
206-
}
207-
collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
208-
&postDomInfo);
209202
} else {
210203
llvm_unreachable("unsupported acc region op");
211204
}

0 commit comments

Comments
 (0)