@@ -76,39 +76,41 @@ static void replaceAllUsesInUnstructuredComputeRegionWith(
7676 Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
7777 DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
7878
79- Operation *exitOp = op. getOperation () ;
79+ SmallVector< Operation *> exitOps ;
8080 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81- // For declare enter/exit pairs, verify there is exactly one exit op using
82- // the token
83- if (!op.getToken ().hasOneUse ())
84- op.emitError (" declare enter token must have exactly one use" );
85- Operation *user = *op.getToken ().getUsers ().begin ();
86- auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
87- if (!declareExit)
88- op.emitError (" declare enter token must be used by declare exit op" );
89- exitOp = declareExit;
90- } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
91- // For enter/exit data pairs, find the corresponding exit_data op
92- Operation *nextOp = op.getOperation ()->getNextNode ();
93- while (nextOp && !isa<acc::ExitDataOp>(nextOp))
94- nextOp = nextOp->getNextNode ();
95- if (!nextOp)
96- op.emitError (" enter data must have a corresponding exit data op" );
97- exitOp = nextOp;
81+ // For declare enter/exit pairs, collect all exit ops
82+ for (auto *user : op.getToken ().getUsers ()) {
83+ if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84+ exitOps.push_back (declareExit);
85+ }
86+ if (exitOps.empty ())
87+ op.emitError (
88+ " declare enter token must be used by at least one declare exit op" );
9889 }
9990
10091 for (auto p : values) {
10192 Value hostVal = std::get<0 >(p);
10293 Value deviceVal = std::get<1 >(p);
10394 for (auto &use : llvm::make_early_inc_range (hostVal.getUses ())) {
10495 Operation *owner = use.getOwner ();
105- // Check that:
106- // It's the case that the acc entry operation dominates the use.
107- // It's the case that one of the acc exit operations consuming the token
96+
97+ // Check It's the case that the acc entry operation dominates the use.
98+ if (!domInfo.dominates (op.getOperation (), owner))
99+ continue ;
100+
101+ // Check It's the case that at least one of the acc exit operations
108102 // post-dominates the use
109- if (!domInfo.dominates (op.getOperation (), owner) ||
110- !postDomInfo.postDominates (exitOp, owner))
103+ bool hasPostDominatingExit = false ;
104+ for (auto *exit : exitOps) {
105+ if (postDomInfo.postDominates (exit, owner)) {
106+ hasPostDominatingExit = true ;
107+ break ;
108+ }
109+ }
110+
111+ if (!hasPostDominatingExit)
111112 continue ;
113+
112114 if (insideAccComputeRegion (owner))
113115 use.set (deviceVal);
114116 }
@@ -131,8 +133,7 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
131133 !std::is_same_v<Op, acc::DataOp> &&
132134 !std::is_same_v<Op, acc::DeclareOp> &&
133135 !std::is_same_v<Op, acc::HostDataOp> &&
134- !std::is_same_v<Op, acc::DeclareEnterOp> &&
135- !std::is_same_v<Op, acc::EnterDataOp>) {
136+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
136137 collectVars (op.getReductionOperands (), values, hostToDevice);
137138 collectVars (op.getPrivateOperands (), values, hostToDevice);
138139 collectVars (op.getFirstprivateOperands (), values, hostToDevice);
@@ -173,7 +174,7 @@ class LegalizeDataValuesInRegion
173174 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
174175 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
175176 applyToAccDataConstruct) &&
176- !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op) )
177+ !isa<acc::DeclareEnterOp>(*op))
177178 return ;
178179
179180 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -198,14 +199,6 @@ class LegalizeDataValuesInRegion
198199 }
199200 collectAndReplaceInRegion (declareEnterOp, replaceHostVsDevice, &domInfo,
200201 &postDomInfo);
201- } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
202- if (!computedDomInfo) {
203- domInfo = DominanceInfo (funcOp);
204- postDomInfo = PostDominanceInfo (funcOp);
205- computedDomInfo = true ;
206- }
207- collectAndReplaceInRegion (enterDataOp, replaceHostVsDevice, &domInfo,
208- &postDomInfo);
209202 } else {
210203 llvm_unreachable (" unsupported acc region op" );
211204 }
0 commit comments