1010
1111#include " mlir/Dialect/Func/IR/FuncOps.h"
1212#include " mlir/Dialect/OpenACC/OpenACC.h"
13+ #include " mlir/IR/Dominance.h"
1314#include " mlir/Pass/Pass.h"
1415#include " mlir/Transforms/RegionUtils.h"
1516#include " llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7071 }
7172}
7273
74+ // Helper function to process declare enter/exit pairs
75+ static void processDeclareEnterExit (
76+ acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
77+ DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78+ // For declare enter/exit pairs, verify there is exactly one exit op using the
79+ // token
80+ if (!op.getToken ().hasOneUse ())
81+ op.emitError (" declare enter token must have exactly one use" );
82+ Operation *user = *op.getToken ().getUsers ().begin ();
83+ auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
84+ if (!declareExit)
85+ op.emitError (" declare enter token must be used by declare exit op" );
86+
87+ for (auto p : values) {
88+ Value hostVal = std::get<0 >(p);
89+ Value deviceVal = std::get<1 >(p);
90+ for (auto &use : llvm::make_early_inc_range (hostVal.getUses ())) {
91+ Operation *owner = use.getOwner ();
92+ if (!domInfo.dominates (op.getOperation (), owner) ||
93+ !postDomInfo.postDominates (declareExit.getOperation (), owner))
94+ continue ;
95+ if (insideAccComputeRegion (owner))
96+ use.set (deviceVal);
97+ }
98+ }
99+ }
100+
73101template <typename Op>
74- static void collectAndReplaceInRegion (Op &op, bool hostToDevice) {
102+ static void
103+ collectAndReplaceInRegion (Op &op, bool hostToDevice,
104+ DominanceInfo *domInfo = nullptr ,
105+ PostDominanceInfo *postDomInfo = nullptr ) {
75106 llvm::SmallVector<std::pair<Value, Value>> values;
76107
77108 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
82113 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83114 !std::is_same_v<Op, acc::DataOp> &&
84115 !std::is_same_v<Op, acc::DeclareOp> &&
85- !std::is_same_v<Op, acc::HostDataOp>) {
116+ !std::is_same_v<Op, acc::HostDataOp> &&
117+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
86118 collectVars (op.getReductionOperands (), values, hostToDevice);
87119 collectVars (op.getPrivateOperands (), values, hostToDevice);
88120 collectVars (op.getFirstprivateOperands (), values, hostToDevice);
89121 }
90122 }
91123
92- for (auto p : values)
93- replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
94- op.getRegion ());
124+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
125+ assert (domInfo && postDomInfo &&
126+ " Dominance info required for DeclareEnterOp" );
127+ processDeclareEnterExit (op, values, *domInfo, *postDomInfo);
128+ } else {
129+ for (auto p : values) {
130+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
131+ op.getRegion ());
132+ }
133+ }
95134}
96135
97136class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
105144 func::FuncOp funcOp = getOperation ();
106145 bool replaceHostVsDevice = this ->hostToDevice .getValue ();
107146
147+ // Get dominance info for the function
148+ DominanceInfo domInfo (funcOp);
149+ PostDominanceInfo postDomInfo (funcOp);
150+
108151 funcOp.walk ([&](Operation *op) {
109152 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110153 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
125168 collectAndReplaceInRegion (declareOp, replaceHostVsDevice);
126169 } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127170 collectAndReplaceInRegion (hostDataOp, replaceHostVsDevice);
171+ } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
172+ collectAndReplaceInRegion (declareEnterOp, replaceHostVsDevice, &domInfo,
173+ &postDomInfo);
128174 } else {
129175 llvm_unreachable (" unsupported acc region op" );
130176 }
0 commit comments