1- // ===- LegalizeData .cpp - ------ -------------------------------------------===//
1+ // ===- LegalizeDataValues .cpp - -------------------------------------------===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
1212#include " mlir/Dialect/OpenACC/OpenACC.h"
1313#include " mlir/Pass/Pass.h"
1414#include " mlir/Transforms/RegionUtils.h"
15+ #include " llvm/Support/ErrorHandling.h"
1516
1617namespace mlir {
1718namespace acc {
18- #define GEN_PASS_DEF_LEGALIZEDATAINREGION
19+ #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
1920#include " mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
2021} // namespace acc
2122} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
2425
2526namespace {
2627
28+ static bool insideAccComputeRegion (mlir::Operation *op) {
29+ mlir::Operation *parent{op->getParentOp ()};
30+ while (parent) {
31+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32+ return true ;
33+ }
34+ parent = parent->getParentOp ();
35+ }
36+ return false ;
37+ }
38+
2739static void collectPtrs (mlir::ValueRange operands,
2840 llvm::SmallVector<std::pair<Value, Value>> &values,
2941 bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
3951 }
4052}
4153
54+ template <typename Op>
55+ static void replaceAllUsesInAccComputeRegionsWith (Value orig, Value replacement,
56+ Region &outerRegion) {
57+ for (auto &use : llvm::make_early_inc_range (orig.getUses ())) {
58+ if (outerRegion.isAncestor (use.getOwner ()->getParentRegion ())) {
59+ if constexpr (std::is_same_v<Op, acc::DataOp> ||
60+ std::is_same_v<Op, acc::DeclareOp>) {
61+ // For data construct regions, only replace uses in contained compute
62+ // regions.
63+ if (insideAccComputeRegion (use.getOwner ())) {
64+ use.set (replacement);
65+ }
66+ } else {
67+ use.set (replacement);
68+ }
69+ }
70+ }
71+ }
72+
4273template <typename Op>
4374static void collectAndReplaceInRegion (Op &op, bool hostToDevice) {
4475 llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,26 +79,35 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
4879 collectPtrs (op.getPrivateOperands (), values, hostToDevice);
4980 } else {
5081 collectPtrs (op.getDataClauseOperands (), values, hostToDevice);
51- if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
82+ if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83+ !std::is_same_v<Op, acc::DataOp> &&
84+ !std::is_same_v<Op, acc::DeclareOp>) {
5285 collectPtrs (op.getReductionOperands (), values, hostToDevice);
5386 collectPtrs (op.getGangPrivateOperands (), values, hostToDevice);
5487 collectPtrs (op.getGangFirstPrivateOperands (), values, hostToDevice);
5588 }
5689 }
5790
5891 for (auto p : values)
59- replaceAllUsesInRegionWith (std::get<0 >(p), std::get<1 >(p), op.getRegion ());
92+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
93+ op.getRegion ());
6094}
6195
62- struct LegalizeDataInRegion
63- : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
96+ class LegalizeDataValuesInRegion
97+ : public acc::impl::LegalizeDataValuesInRegionBase<
98+ LegalizeDataValuesInRegion> {
99+ public:
100+ using LegalizeDataValuesInRegionBase<
101+ LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
64102
65103 void runOnOperation () override {
66104 func::FuncOp funcOp = getOperation ();
67105 bool replaceHostVsDevice = this ->hostToDevice .getValue ();
68106
69107 funcOp.walk ([&](Operation *op) {
70- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
108+ if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109+ !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110+ applyToAccDataConstruct))
71111 return ;
72112
73113 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
78118 collectAndReplaceInRegion (kernelsOp, replaceHostVsDevice);
79119 } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80120 collectAndReplaceInRegion (loopOp, replaceHostVsDevice);
121+ } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122+ collectAndReplaceInRegion (dataOp, replaceHostVsDevice);
123+ } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124+ collectAndReplaceInRegion (declareOp, replaceHostVsDevice);
125+ } else {
126+ llvm_unreachable (" unsupported acc region op" );
81127 }
82128 });
83129 }
84130};
85131
86132} // end anonymous namespace
87-
88- std::unique_ptr<OperationPass<func::FuncOp>>
89- mlir::acc::createLegalizeDataInRegion () {
90- return std::make_unique<LegalizeDataInRegion>();
91- }
0 commit comments