diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index f2abeab744d17..6d316d282278d 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/ErrorHandling.h" @@ -71,7 +72,55 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, } template -static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { +static void replaceAllUsesInUnstructuredComputeRegionWith( + Op &op, llvm::SmallVector> &values, + DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) { + + SmallVector exitOps; + if constexpr (std::is_same_v) { + // For declare enter/exit pairs, collect all exit ops + for (auto *user : op.getToken().getUsers()) { + if (auto declareExit = dyn_cast(user)) + exitOps.push_back(declareExit); + } + if (exitOps.empty()) + return; + } + + for (auto p : values) { + Value hostVal = std::get<0>(p); + Value deviceVal = std::get<1>(p); + for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) { + Operation *owner = use.getOwner(); + + // Check that the acc entry operation dominates the use. 
+ if (!domInfo.dominates(op.getOperation(), owner)) + continue; + + // Check that at least one of the acc exit operations + // post-dominates the use + bool hasPostDominatingExit = false; + for (auto *exit : exitOps) { + if (postDomInfo.postDominates(exit, owner)) { + hasPostDominatingExit = true; + break; + } + } + + if (!hasPostDominatingExit) + continue; + + if (insideAccComputeRegion(owner)) + use.set(deviceVal); + } + } +} + +template +static void +collectAndReplaceInRegion(Op &op, bool hostToDevice, + DominanceInfo *domInfo = nullptr, + PostDominanceInfo *postDomInfo = nullptr) { llvm::SmallVector> values; if constexpr (std::is_same_v) { @@ -82,16 +131,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { if constexpr (!std::is_same_v && !std::is_same_v && !std::is_same_v && - !std::is_same_v) { + !std::is_same_v && + !std::is_same_v) { collectVars(op.getReductionOperands(), values, hostToDevice); collectVars(op.getPrivateOperands(), values, hostToDevice); collectVars(op.getFirstprivateOperands(), values, hostToDevice); } } - for (auto p : values) - replaceAllUsesInAccComputeRegionsWith(std::get<0>(p), std::get<1>(p), - op.getRegion()); + if constexpr (std::is_same_v) { + assert(domInfo && postDomInfo && + "Dominance info required for DeclareEnterOp"); + replaceAllUsesInUnstructuredComputeRegionWith(op, values, *domInfo, + *postDomInfo); + } else { + for (auto p : values) { + replaceAllUsesInAccComputeRegionsWith(std::get<0>(p), std::get<1>(p), + op.getRegion()); + } + } } class LegalizeDataValuesInRegion @@ -105,10 +163,16 @@ class LegalizeDataValuesInRegion func::FuncOp funcOp = getOperation(); bool replaceHostVsDevice = this->hostToDevice.getValue(); + // Dominance info, computed lazily for the first declare enter op + DominanceInfo domInfo; + PostDominanceInfo postDomInfo; + bool computedDomInfo = false; + funcOp.walk([&](Operation *op) { if (!isa(*op) && !(isa(*op) && - applyToAccDataConstruct)) + applyToAccDataConstruct) && + !isa(*op)) return; if (auto 
parallelOp = dyn_cast(*op)) { @@ -125,6 +189,14 @@ class LegalizeDataValuesInRegion collectAndReplaceInRegion(declareOp, replaceHostVsDevice); } else if (auto hostDataOp = dyn_cast(*op)) { collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice); + } else if (auto declareEnterOp = dyn_cast(*op)) { + if (!computedDomInfo) { + domInfo = DominanceInfo(funcOp); + postDomInfo = PostDominanceInfo(funcOp); + computedDomInfo = true; + } + collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo, + &postDomInfo); } else { llvm_unreachable("unsupported acc region op"); } diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index 9461225e9a7e0..28ef6761a6ef4 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -245,4 +245,31 @@ func.func private @foo(memref<10xf32>) // CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) { // DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> () // CHECK: acc.terminator -// CHECK: } \ No newline at end of file +// CHECK: } + +// ----- + +func.func @test(%a: memref<10xf32>) { + %declare = acc.create varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"} + %token = acc.declare_enter dataOperands(%declare : memref<10xf32>) + acc.kernels dataOperands(%declare : memref<10xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.000000e+00 : f32 + memref.store %c1, %a[%c0] : memref<10xf32> + acc.terminator + } + acc.declare_exit token(%token) dataOperands(%declare : memref<10xf32>) + return +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[DECLARE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"} +// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DECLARE]] : memref<10xf32>) +// CHECK: acc.kernels dataOperands(%[[DECLARE]] : memref<10xf32>) { +// 
DEVICE: memref.store %{{.*}}, %[[DECLARE]][%{{.*}}] : memref<10xf32> +// HOST: memref.store %{{.*}}, %[[A]][%{{.*}}] : memref<10xf32> +// CHECK: acc.terminator +// CHECK: } +// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DECLARE]] : memref<10xf32>) +// CHECK: return \ No newline at end of file