@@ -71,26 +71,43 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7171 }
7272}
7373
74- // Helper function to process declare enter/exit pairs
75- static void processDeclareEnterExit (
76- acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
74+ template < typename Op>
75+ static void replaceAllUsesInUnstructuredComputeRegionWith (
76+ Op & op, llvm::SmallVector<std::pair<Value, Value>> &values,
7777 DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78- // For declare enter/exit pairs, verify there is exactly one exit op using the
79- // token
80- if (!op.getToken ().hasOneUse ())
81- op.emitError (" declare enter token must have exactly one use" );
82- Operation *user = *op.getToken ().getUsers ().begin ();
83- auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
84- if (!declareExit)
85- op.emitError (" declare enter token must be used by declare exit op" );
78+
79+ Operation *exitOp = op.getOperation ();
80+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81+ // For declare enter/exit pairs, verify there is exactly one exit op using
82+ // the token
83+ if (!op.getToken ().hasOneUse ())
84+ op.emitError (" declare enter token must have exactly one use" );
85+ Operation *user = *op.getToken ().getUsers ().begin ();
86+ auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
87+ if (!declareExit)
88+ op.emitError (" declare enter token must be used by declare exit op" );
89+ exitOp = declareExit;
90+ } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
91+ // For enter/exit data pairs, find the corresponding exit_data op
92+ Operation *nextOp = op.getOperation ()->getNextNode ();
93+ while (nextOp && !isa<acc::ExitDataOp>(nextOp))
94+ nextOp = nextOp->getNextNode ();
95+ if (!nextOp)
96+ op.emitError (" enter data must have a corresponding exit data op" );
97+ exitOp = nextOp;
98+ }
8699
87100 for (auto p : values) {
88101 Value hostVal = std::get<0 >(p);
89102 Value deviceVal = std::get<1 >(p);
90103 for (auto &use : llvm::make_early_inc_range (hostVal.getUses ())) {
91104 Operation *owner = use.getOwner ();
105+ // Check that:
106+ // It's the case that the acc entry operation dominates the use.
107+ // It's the case that one of the acc exit operations consuming the token
108+ // post-dominates the use
92109 if (!domInfo.dominates (op.getOperation (), owner) ||
93- !postDomInfo.postDominates (declareExit. getOperation () , owner))
110+ !postDomInfo.postDominates (exitOp , owner))
94111 continue ;
95112 if (insideAccComputeRegion (owner))
96113 use.set (deviceVal);
@@ -114,17 +131,20 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
114131 !std::is_same_v<Op, acc::DataOp> &&
115132 !std::is_same_v<Op, acc::DeclareOp> &&
116133 !std::is_same_v<Op, acc::HostDataOp> &&
117- !std::is_same_v<Op, acc::DeclareEnterOp>) {
134+ !std::is_same_v<Op, acc::DeclareEnterOp> &&
135+ !std::is_same_v<Op, acc::EnterDataOp>) {
118136 collectVars (op.getReductionOperands (), values, hostToDevice);
119137 collectVars (op.getPrivateOperands (), values, hostToDevice);
120138 collectVars (op.getFirstprivateOperands (), values, hostToDevice);
121139 }
122140 }
123141
124- if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
142+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
143+ std::is_same_v<Op, acc::EnterDataOp>) {
125144 assert (domInfo && postDomInfo &&
126145 " Dominance info required for DeclareEnterOp" );
127- processDeclareEnterExit (op, values, *domInfo, *postDomInfo);
146+ replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
147+ *postDomInfo);
128148 } else {
129149 for (auto p : values) {
130150 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
@@ -151,7 +171,8 @@ class LegalizeDataValuesInRegion
151171 funcOp.walk ([&](Operation *op) {
152172 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
153173 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
154- applyToAccDataConstruct))
174+ applyToAccDataConstruct) &&
175+ !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
155176 return ;
156177
157178 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -171,6 +192,9 @@ class LegalizeDataValuesInRegion
171192 } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
172193 collectAndReplaceInRegion (declareEnterOp, replaceHostVsDevice, &domInfo,
173194 &postDomInfo);
195+ } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
196+ collectAndReplaceInRegion (enterDataOp, replaceHostVsDevice, &domInfo,
197+ &postDomInfo);
174198 } else {
175199 llvm_unreachable (" unsupported acc region op" );
176200 }
0 commit comments