1010
1111#include " mlir/Dialect/Func/IR/FuncOps.h"
1212#include " mlir/Dialect/OpenACC/OpenACC.h"
13+ #include " mlir/IR/Dominance.h"
1314#include " mlir/Pass/Pass.h"
1415#include " mlir/Transforms/RegionUtils.h"
1516#include " llvm/Support/ErrorHandling.h"
@@ -71,7 +72,55 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7172}
7273
7374template <typename Op>
74- static void collectAndReplaceInRegion (Op &op, bool hostToDevice) {
75+ static void replaceAllUsesInUnstructuredComputeRegionWith (
76+ Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
77+ DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78+
79+ SmallVector<Operation *> exitOps;
80+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81+ // For declare enter/exit pairs, collect all exit ops
82+ for (auto *user : op.getToken ().getUsers ()) {
83+ if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84+ exitOps.push_back (declareExit);
85+ }
86+ if (exitOps.empty ())
87+ return ;
88+ }
89+
90+ for (auto p : values) {
91+ Value hostVal = std::get<0 >(p);
92+ Value deviceVal = std::get<1 >(p);
93+ for (auto &use : llvm::make_early_inc_range (hostVal.getUses ())) {
94+ Operation *owner = use.getOwner ();
95+
96+ // Check It's the case that the acc entry operation dominates the use.
97+ if (!domInfo.dominates (op.getOperation (), owner))
98+ continue ;
99+
100+ // Check It's the case that at least one of the acc exit operations
101+ // post-dominates the use
102+ bool hasPostDominatingExit = false ;
103+ for (auto *exit : exitOps) {
104+ if (postDomInfo.postDominates (exit, owner)) {
105+ hasPostDominatingExit = true ;
106+ break ;
107+ }
108+ }
109+
110+ if (!hasPostDominatingExit)
111+ continue ;
112+
113+ if (insideAccComputeRegion (owner))
114+ use.set (deviceVal);
115+ }
116+ }
117+ }
118+
119+ template <typename Op>
120+ static void
121+ collectAndReplaceInRegion (Op &op, bool hostToDevice,
122+ DominanceInfo *domInfo = nullptr ,
123+ PostDominanceInfo *postDomInfo = nullptr ) {
75124 llvm::SmallVector<std::pair<Value, Value>> values;
76125
77126 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +131,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
82131 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83132 !std::is_same_v<Op, acc::DataOp> &&
84133 !std::is_same_v<Op, acc::DeclareOp> &&
85- !std::is_same_v<Op, acc::HostDataOp>) {
134+ !std::is_same_v<Op, acc::HostDataOp> &&
135+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
86136 collectVars (op.getReductionOperands (), values, hostToDevice);
87137 collectVars (op.getPrivateOperands (), values, hostToDevice);
88138 collectVars (op.getFirstprivateOperands (), values, hostToDevice);
89139 }
90140 }
91141
92- for (auto p : values)
93- replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
94- op.getRegion ());
142+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143+ assert (domInfo && postDomInfo &&
144+ " Dominance info required for DeclareEnterOp" );
145+ replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
146+ *postDomInfo);
147+ } else {
148+ for (auto p : values) {
149+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
150+ op.getRegion ());
151+ }
152+ }
95153}
96154
97155class LegalizeDataValuesInRegion
@@ -105,10 +163,16 @@ class LegalizeDataValuesInRegion
105163 func::FuncOp funcOp = getOperation ();
106164 bool replaceHostVsDevice = this ->hostToDevice .getValue ();
107165
166+ // Initialize dominance info
167+ DominanceInfo domInfo;
168+ PostDominanceInfo postDomInfo;
169+ bool computedDomInfo = false ;
170+
108171 funcOp.walk ([&](Operation *op) {
109172 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110173 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
111- applyToAccDataConstruct))
174+ applyToAccDataConstruct) &&
175+ !isa<acc::DeclareEnterOp>(*op))
112176 return ;
113177
114178 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -125,6 +189,14 @@ class LegalizeDataValuesInRegion
125189 collectAndReplaceInRegion (declareOp, replaceHostVsDevice);
126190 } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127191 collectAndReplaceInRegion (hostDataOp, replaceHostVsDevice);
192+ } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193+ if (!computedDomInfo) {
194+ domInfo = DominanceInfo (funcOp);
195+ postDomInfo = PostDominanceInfo (funcOp);
196+ computedDomInfo = true ;
197+ }
198+ collectAndReplaceInRegion (declareEnterOp, replaceHostVsDevice, &domInfo,
199+ &postDomInfo);
128200 } else {
129201 llvm_unreachable (" unsupported acc region op" );
130202 }
0 commit comments