|
14 | 14 |
|
15 | 15 | #include "ClauseFinder.h" |
16 | 16 | #include "flang/Evaluate/fold.h" |
| 17 | +#include "flang/Evaluate/tools.h" |
17 | 18 | #include <flang/Lower/AbstractConverter.h> |
18 | 19 | #include <flang/Lower/ConvertType.h> |
19 | 20 | #include <flang/Lower/DirectivesCommon.h> |
20 | 21 | #include <flang/Lower/OpenMP/Clauses.h> |
21 | 22 | #include <flang/Lower/PFTBuilder.h> |
22 | 23 | #include <flang/Lower/Support/PrivateReductionUtils.h> |
| 24 | +#include <flang/Optimizer/Builder/BoxValue.h> |
23 | 25 | #include <flang/Optimizer/Builder/FIRBuilder.h> |
24 | 26 | #include <flang/Optimizer/Builder/Todo.h> |
| 27 | +#include <flang/Optimizer/HLFIR/HLFIROps.h> |
25 | 28 | #include <flang/Parser/openmp-utils.h> |
26 | 29 | #include <flang/Parser/parse-tree.h> |
27 | 30 | #include <flang/Parser/tools.h> |
28 | 31 | #include <flang/Semantics/tools.h> |
29 | 32 | #include <flang/Semantics/type.h> |
30 | 33 | #include <flang/Utils/OpenMP.h> |
| 34 | +#include <llvm/ADT/SmallPtrSet.h> |
| 35 | +#include <llvm/ADT/StringRef.h> |
31 | 36 | #include <llvm/Support/CommandLine.h> |
32 | 37 |
|
| 38 | +#include <functional> |
33 | 39 | #include <iterator> |
34 | 40 |
|
35 | 41 | template <typename T> |
@@ -61,6 +67,142 @@ namespace Fortran { |
61 | 67 | namespace lower { |
62 | 68 | namespace omp { |
63 | 69 |
|
| 70 | +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( |
| 71 | + lower::AbstractConverter &converter, mlir::Location loc, |
| 72 | + fir::RecordType recordType, llvm::StringRef mapperNameStr) { |
| 73 | + if (mapperNameStr.empty()) |
| 74 | + return {}; |
| 75 | + |
| 76 | + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) |
| 77 | + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), |
| 78 | + mapperNameStr); |
| 79 | + |
| 80 | + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 81 | + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); |
| 82 | + |
| 83 | + firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); |
| 84 | + auto declMapperOp = mlir::omp::DeclareMapperOp::create( |
| 85 | + firOpBuilder, loc, mapperNameStr, recordType); |
| 86 | + auto ®ion = declMapperOp.getRegion(); |
| 87 | + firOpBuilder.createBlock(®ion); |
| 88 | + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); |
| 89 | + |
| 90 | + auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, |
| 91 | + /*uniq_name=*/""); |
| 92 | + |
| 93 | + const auto genBoundsOps = [&](mlir::Value mapVal, |
| 94 | + llvm::SmallVectorImpl<mlir::Value> &bounds) { |
| 95 | + fir::ExtendedValue extVal = |
| 96 | + hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder, |
| 97 | + hlfir::Entity{mapVal}, |
| 98 | + /*contiguousHint=*/true) |
| 99 | + .first; |
| 100 | + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( |
| 101 | + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); |
| 102 | + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, |
| 103 | + mlir::omp::MapBoundsType>( |
| 104 | + firOpBuilder, info, extVal, |
| 105 | + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); |
| 106 | + }; |
| 107 | + |
| 108 | + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, |
| 109 | + mlir::Type fieldTy, mlir::Type recType) { |
| 110 | + mlir::Value field = fir::FieldIndexOp::create( |
| 111 | + firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName, |
| 112 | + recType, fir::getTypeParams(rec)); |
| 113 | + return fir::CoordinateOp::create( |
| 114 | + firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); |
| 115 | + }; |
| 116 | + |
| 117 | + llvm::SmallVector<mlir::Value> clauseMapVars; |
| 118 | + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; |
| 119 | + llvm::SmallVector<mlir::Value> memberMapOps; |
| 120 | + |
| 121 | + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | |
| 122 | + mlir::omp::ClauseMapFlags::from | |
| 123 | + mlir::omp::ClauseMapFlags::implicit; |
| 124 | + mlir::omp::VariableCaptureKind captureKind = |
| 125 | + mlir::omp::VariableCaptureKind::ByRef; |
| 126 | + |
| 127 | + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { |
| 128 | + const auto &memberName = entry.value().first; |
| 129 | + const auto &memberType = entry.value().second; |
| 130 | + mlir::FlatSymbolRefAttr mapperId; |
| 131 | + if (auto recType = mlir::dyn_cast<fir::RecordType>( |
| 132 | + fir::getFortranElementType(memberType))) { |
| 133 | + std::string mapperIdName = |
| 134 | + recType.getName().str() + llvm::omp::OmpDefaultMapperName; |
| 135 | + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) |
| 136 | + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); |
| 137 | + else if (auto *memberSym = |
| 138 | + converter.getCurrentScope().FindSymbol(memberName)) |
| 139 | + mapperIdName = converter.mangleName(mapperIdName, memberSym->owner()); |
| 140 | + |
| 141 | + mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType, |
| 142 | + mapperIdName); |
| 143 | + } |
| 144 | + |
| 145 | + auto ref = |
| 146 | + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); |
| 147 | + llvm::SmallVector<mlir::Value> bounds; |
| 148 | + genBoundsOps(ref, bounds); |
| 149 | + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp( |
| 150 | + firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", |
| 151 | + bounds, |
| 152 | + /*members=*/{}, |
| 153 | + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), |
| 154 | + /*partialMap=*/false, mapperId); |
| 155 | + memberMapOps.emplace_back(mapOp); |
| 156 | + memberPlacementIndices.emplace_back( |
| 157 | + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); |
| 158 | + } |
| 159 | + |
| 160 | + llvm::SmallVector<mlir::Value> bounds; |
| 161 | + genBoundsOps(declareOp.getOriginalBase(), bounds); |
| 162 | + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; |
| 163 | + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( |
| 164 | + firOpBuilder, loc, declareOp.getOriginalBase(), |
| 165 | + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, |
| 166 | + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, |
| 167 | + captureKind, declareOp.getType(0), |
| 168 | + /*partialMap=*/true); |
| 169 | + |
| 170 | + clauseMapVars.emplace_back(mapOp); |
| 171 | + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); |
| 172 | + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), |
| 173 | + mapperNameStr); |
| 174 | +} |
| 175 | + |
| 176 | +bool requiresImplicitDefaultDeclareMapper( |
| 177 | + const semantics::DerivedTypeSpec &typeSpec) { |
| 178 | + // ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit |
| 179 | + // default mappers available so that OpenMP offloading can correctly map them. |
| 180 | + if (semantics::IsIsoCType(&typeSpec)) |
| 181 | + return true; |
| 182 | + |
| 183 | + llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited; |
| 184 | + |
| 185 | + std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper = |
| 186 | + [&](const semantics::DerivedTypeSpec &spec) -> bool { |
| 187 | + if (!visited.insert(&spec).second) |
| 188 | + return false; |
| 189 | + |
| 190 | + semantics::DirectComponentIterator directComponents{spec}; |
| 191 | + for (const semantics::Symbol &component : directComponents) { |
| 192 | + if (semantics::IsAllocatableOrPointer(component)) |
| 193 | + return true; |
| 194 | + |
| 195 | + if (const semantics::DeclTypeSpec *declType = component.GetType()) |
| 196 | + if (const auto *nested = declType->AsDerived()) |
| 197 | + if (requiresMapper(*nested)) |
| 198 | + return true; |
| 199 | + } |
| 200 | + return false; |
| 201 | + }; |
| 202 | + |
| 203 | + return requiresMapper(typeSpec); |
| 204 | +} |
| 205 | + |
64 | 206 | int64_t getCollapseValue(const List<Clause> &clauses) { |
65 | 207 | auto iter = llvm::find_if(clauses, [](const Clause &clause) { |
66 | 208 | return clause.id == llvm::omp::Clause::OMPC_collapse; |
@@ -537,6 +679,12 @@ void insertChildMapInfoIntoParent( |
537 | 679 | mapOperands[std::distance(mapSyms.begin(), parentIter)] |
538 | 680 | .getDefiningOp()); |
539 | 681 |
|
| 682 | + // Once explicit members are attached to a parent map, do not also invoke |
| 683 | + // a declare mapper on it, otherwise the mapper would remap the same |
| 684 | + // components leading to duplicate mappings at runtime. |
| 685 | + if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr()) |
| 686 | + mapOp.setMapperIdAttr(nullptr); |
| 687 | + |
540 | 688 | // NOTE: To maintain appropriate SSA ordering, we move the parent map |
541 | 689 | // which will now have references to its children after the last |
542 | 690 | // of its members to be generated. This is necessary when a user |
|
0 commit comments