diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir index 3b8695434e6e4..6bc81dc08db30 100644 --- a/flang/test/Fir/OpenACC/legalize-data.fir +++ b/flang/test/Fir/OpenACC/legalize-data.fir @@ -1,4 +1,4 @@ -// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s +// RUN: fir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s func.func @_QPsub1(%arg0: !fir.ref {fir.bindc_name = "i"}) { %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) @@ -22,3 +22,36 @@ func.func @_QPsub1(%arg0: !fir.ref {fir.bindc_name = "i"}) { // CHECK: acc.yield // CHECK: } // CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref) to varPtr(%[[I]]#0 : !fir.ref) {dataClause = #acc, name = "i"} + +// ----- + +func.func @_QPsub1(%arg0: !fir.ref {fir.bindc_name = "i"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %1 = acc.copyin varPtr(%0#0 : !fir.ref) -> !fir.ref {dataClause = #acc, name = "i"} + acc.data dataOperands(%1 : !fir.ref) { + %c0_i32 = arith.constant 0 : i32 + hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref + acc.serial { + hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref + acc.yield + } + acc.terminator + } + acc.copyout accPtr(%1 : !fir.ref) to varPtr(%0#0 : !fir.ref) {dataClause = #acc, name = "i"} + return +} + +// CHECK-LABEL: func.func @_QPsub1 +// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref {fir.bindc_name = "i"}) +// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref) -> !fir.ref {dataClause = #acc, name = "i"} +// CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: hlfir.assign %[[C0]] to %[[I]]#0 : i32, !fir.ref +// CHECK: acc.serial { +// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref +// CHECK: acc.yield +// CHECK: } +// CHECK: 
acc.terminator +// CHECK: } +// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref) to varPtr(%[[I]]#0 : !fir.ref) {dataClause = #acc, name = "i"} diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index ca96ce62ae404..cda07d6a91364 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -56,14 +56,14 @@ mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \ ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp -#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \ +#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \ mlir::acc::DataOp, mlir::acc::DeclareOp #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \ mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \ mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \ mlir::acc::DeclareExitOp #define ACC_DATA_CONSTRUCT_OPS \ - OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS + ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS #define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \ ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS #define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \ diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h index bb93c78bf6ead..57d532b078b9e 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h @@ -11,9 +11,6 @@ #include "mlir/Pass/Pass.h" -#define GEN_PASS_DECL -#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" - namespace mlir { namespace func { @@ -22,8 +19,8 @@ class FuncOp; namespace acc { -/// Create a pass to replace ssa values in region with device/host values. -std::unique_ptr> createLegalizeDataInRegion(); +#define GEN_PASS_DECL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" /// Generate the code for registering conversion passes. 
#define GEN_PASS_REGISTRATION diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td index abbc27765e342..9ceb91e5679a1 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td @@ -11,18 +11,20 @@ include "mlir/Pass/PassBase.td" -def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> { - let summary = "Legalize the data in the compute region"; +def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> { + let summary = "Legalizes SSA values in compute regions with results from data clause operations"; let description = [{ - This pass replace uses of varPtr in the compute region with their accPtr - gathered from the data clause operands. + This pass replaces uses of the `varPtr` in compute regions (kernels, + parallel, serial) with the result of data clause operations (`accPtr`). }]; let options = [ Option<"hostToDevice", "host-to-device", "bool", "true", "Replace varPtr uses with accPtr if true. 
Replace accPtr uses with " - "varPtr if false"> + "varPtr if false">, + Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true", + "Replaces varPtr uses with accPtr for acc compute regions contained " + "within acc.data or acc.declare region."> ]; - let constructor = "::mlir::acc::createLegalizeDataInRegion()"; } #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 41ba7f8f53d36..7d934956089a5 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIROpenACCTransforms - LegalizeData.cpp + LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp similarity index 54% rename from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp rename to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index db6b472ff9733..4038e333adb8b 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -1,4 +1,4 @@ -//===- LegalizeData.cpp - -------------------------------------------------===// +//===- LegalizeDataValues.cpp - -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -12,10 +12,11 @@ #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/Support/ErrorHandling.h" namespace mlir { namespace acc { -#define GEN_PASS_DEF_LEGALIZEDATAINREGION +#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" } // namespace acc } // namespace mlir @@ -24,6 +25,17 @@ using namespace mlir; namespace { +static bool insideAccComputeRegion(mlir::Operation *op) { + mlir::Operation *parent{op->getParentOp()}; + while (parent) { + if (isa(parent)) { + return true; + } + parent = parent->getParentOp(); + } + return false; +} + static void collectPtrs(mlir::ValueRange operands, llvm::SmallVector> &values, bool hostToDevice) { @@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands, } } +template +static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, + Region &outerRegion) { + for (auto &use : llvm::make_early_inc_range(orig.getUses())) { + if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) { + if constexpr (std::is_same_v || + std::is_same_v) { + // For data construct regions, only replace uses in contained compute + // regions. 
+ if (insideAccComputeRegion(use.getOwner())) { + use.set(replacement); + } + } else { + use.set(replacement); + } + } + } +} + template static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { llvm::SmallVector> values; @@ -48,7 +79,9 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { collectPtrs(op.getPrivateOperands(), values, hostToDevice); } else { collectPtrs(op.getDataClauseOperands(), values, hostToDevice); - if constexpr (!std::is_same_v) { + if constexpr (!std::is_same_v && + !std::is_same_v && + !std::is_same_v) { collectPtrs(op.getReductionOperands(), values, hostToDevice); collectPtrs(op.getGangPrivateOperands(), values, hostToDevice); collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice); @@ -56,18 +89,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { } for (auto p : values) - replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion()); + replaceAllUsesInAccComputeRegionsWith(std::get<0>(p), std::get<1>(p), + op.getRegion()); } -struct LegalizeDataInRegion - : public acc::impl::LegalizeDataInRegionBase { +class LegalizeDataValuesInRegion + : public acc::impl::LegalizeDataValuesInRegionBase< + LegalizeDataValuesInRegion> { +public: + using LegalizeDataValuesInRegionBase< + LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase; void runOnOperation() override { func::FuncOp funcOp = getOperation(); bool replaceHostVsDevice = this->hostToDevice.getValue(); funcOp.walk([&](Operation *op) { - if (!isa(*op) && !isa(*op)) + if (!isa(*op) && + !(isa(*op) && + applyToAccDataConstruct)) return; if (auto parallelOp = dyn_cast(*op)) { @@ -78,14 +118,15 @@ struct LegalizeDataInRegion collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); } else if (auto loopOp = dyn_cast(*op)) { collectAndReplaceInRegion(loopOp, replaceHostVsDevice); + } else if (auto dataOp = dyn_cast(*op)) { + collectAndReplaceInRegion(dataOp, replaceHostVsDevice); + } else if (auto declareOp = 
dyn_cast(*op)) { + collectAndReplaceInRegion(declareOp, replaceHostVsDevice); + } else { + llvm_unreachable("unsupported acc region op"); } }); } }; } // end anonymous namespace - -std::unique_ptr> -mlir::acc::createLegalizeDataInRegion() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index 113fe90450ab7..842f8e260c499 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE -// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST +// RUN: mlir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s --check-prefixes=CHECK,DEVICE +// RUN: mlir-opt -split-input-file --openacc-legalize-data-values=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST func.func @test(%a: memref<10xf32>, %i : index) { %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> @@ -61,6 +61,32 @@ func.func @test(%a: memref<10xf32>, %i : index) { // ----- +func.func @test(%a: memref<10xf32>, %i : index) { + %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.data dataOperands(%create : memref<10xf32>) { + %c0 = arith.constant 0.000000e+00 : f32 + memref.store %c0, %a[%i] : memref<10xf32> + acc.serial { + %cs = memref.load %a[%i] : memref<10xf32> + acc.yield + } + acc.terminator + } + return +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index) +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.data dataOperands(%[[CREATE]] : memref<10xf32>) { +// CHECK: memref.store %{{.*}}, %[[A]][%[[I]]] : memref<10xf32> +// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32> +// HOST: %{{.*}} = 
memref.load %[[A]][%[[I]]] : memref<10xf32> +// CHECK: acc.terminator +// CHECK: } + +// ----- + func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index