|
20 | 20 | #include <flang/Lower/OpenMP/Clauses.h> |
21 | 21 | #include <flang/Lower/PFTBuilder.h> |
22 | 22 | #include <flang/Lower/Support/PrivateReductionUtils.h> |
| 23 | +#include <flang/Optimizer/Builder/BoxValue.h> |
23 | 24 | #include <flang/Optimizer/Builder/FIRBuilder.h> |
24 | 25 | #include <flang/Optimizer/Builder/Todo.h> |
| 26 | +#include <flang/Optimizer/HLFIR/HLFIROps.h> |
25 | 27 | #include <flang/Parser/openmp-utils.h> |
26 | 28 | #include <flang/Parser/parse-tree.h> |
27 | 29 | #include <flang/Parser/tools.h> |
28 | 30 | #include <flang/Semantics/tools.h> |
29 | 31 | #include <flang/Semantics/type.h> |
30 | 32 | #include <flang/Utils/OpenMP.h> |
| 33 | +#include <llvm/ADT/SmallPtrSet.h> |
| 34 | +#include <llvm/ADT/StringRef.h> |
31 | 35 | #include <llvm/Support/CommandLine.h> |
32 | 36 |
|
| 37 | +#include <functional> |
33 | 38 | #include <iterator> |
34 | 39 |
|
35 | 40 | template <typename T> |
@@ -61,6 +66,139 @@ namespace Fortran { |
61 | 66 | namespace lower { |
62 | 67 | namespace omp { |
63 | 68 |
|
| 69 | +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( |
| 70 | + lower::AbstractConverter &converter, mlir::Location loc, |
| 71 | + fir::RecordType recordType, llvm::StringRef mapperNameStr) { |
| 72 | + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) |
| 73 | + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), |
| 74 | + mapperNameStr); |
| 75 | + |
| 76 | + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 77 | + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); |
| 78 | + |
| 79 | + firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); |
| 80 | + auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>( |
| 81 | + loc, mapperNameStr, recordType); |
| 82 | + auto ®ion = declMapperOp.getRegion(); |
| 83 | + firOpBuilder.createBlock(®ion); |
| 84 | + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); |
| 85 | + |
| 86 | + auto declareOp = |
| 87 | + firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/""); |
| 88 | + |
| 89 | + const auto genBoundsOps = [&](mlir::Value mapVal, |
| 90 | + llvm::SmallVectorImpl<mlir::Value> &bounds) { |
| 91 | + fir::ExtendedValue extVal = |
| 92 | + hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder, |
| 93 | + hlfir::Entity{mapVal}, |
| 94 | + /*contiguousHint=*/true) |
| 95 | + .first; |
| 96 | + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( |
| 97 | + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); |
| 98 | + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, |
| 99 | + mlir::omp::MapBoundsType>( |
| 100 | + firOpBuilder, info, extVal, |
| 101 | + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); |
| 102 | + }; |
| 103 | + |
| 104 | + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, |
| 105 | + mlir::Type fieldTy, mlir::Type recType) { |
| 106 | + mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>( |
| 107 | + loc, fir::FieldType::get(recType.getContext()), fieldName, recType, |
| 108 | + fir::getTypeParams(rec)); |
| 109 | + return firOpBuilder.create<fir::CoordinateOp>( |
| 110 | + loc, firOpBuilder.getRefType(fieldTy), rec, field); |
| 111 | + }; |
| 112 | + |
| 113 | + mlir::omp::DeclareMapperInfoOperands clauseOps; |
| 114 | + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; |
| 115 | + llvm::SmallVector<mlir::Value> memberMapOps; |
| 116 | + |
| 117 | + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to; |
| 118 | + mapFlag |= mlir::omp::ClauseMapFlags::from; |
| 119 | + mapFlag |= mlir::omp::ClauseMapFlags::implicit; |
| 120 | + mlir::omp::VariableCaptureKind captureKind = |
| 121 | + mlir::omp::VariableCaptureKind::ByRef; |
| 122 | + |
| 123 | + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { |
| 124 | + const auto &memberName = entry.value().first; |
| 125 | + const auto &memberType = entry.value().second; |
| 126 | + mlir::FlatSymbolRefAttr mapperId; |
| 127 | + if (auto recType = mlir::dyn_cast<fir::RecordType>( |
| 128 | + fir::getFortranElementType(memberType))) { |
| 129 | + std::string mapperIdName = |
| 130 | + recType.getName().str() + llvm::omp::OmpDefaultMapperName; |
| 131 | + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) |
| 132 | + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); |
| 133 | + else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName)) |
| 134 | + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); |
| 135 | + |
| 136 | + mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType, |
| 137 | + mapperIdName); |
| 138 | + } |
| 139 | + |
| 140 | + auto ref = |
| 141 | + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); |
| 142 | + llvm::SmallVector<mlir::Value> bounds; |
| 143 | + genBoundsOps(ref, bounds); |
| 144 | + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp( |
| 145 | + firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", |
| 146 | + bounds, |
| 147 | + /*members=*/{}, |
| 148 | + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), |
| 149 | + /*partialMap=*/false, mapperId); |
| 150 | + memberMapOps.emplace_back(mapOp); |
| 151 | + memberPlacementIndices.emplace_back( |
| 152 | + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); |
| 153 | + } |
| 154 | + |
| 155 | + llvm::SmallVector<mlir::Value> bounds; |
| 156 | + genBoundsOps(declareOp.getOriginalBase(), bounds); |
| 157 | + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; |
| 158 | + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( |
| 159 | + firOpBuilder, loc, declareOp.getOriginalBase(), |
| 160 | + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, |
| 161 | + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, |
| 162 | + captureKind, declareOp.getType(0), |
| 163 | + /*partialMap=*/true); |
| 164 | + |
| 165 | + clauseOps.mapVars.emplace_back(mapOp); |
| 166 | + firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars); |
| 167 | + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), |
| 168 | + mapperNameStr); |
| 169 | +} |
| 170 | + |
| 171 | +bool requiresImplicitDefaultDeclareMapper( |
| 172 | + const semantics::DerivedTypeSpec &typeSpec) { |
| 173 | + std::string rawName = typeSpec.name().ToString(); |
| 174 | + std::string loweredName = llvm::StringRef(rawName).lower(); |
| 175 | + llvm::StringRef typeNameRef(loweredName); |
| 176 | + if (typeNameRef.contains("c_ptr") || typeNameRef.contains("c_funptr")) |
| 177 | + return true; |
| 178 | + |
| 179 | + llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited; |
| 180 | + |
| 181 | + std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper = |
| 182 | + [&](const semantics::DerivedTypeSpec &spec) -> bool { |
| 183 | + if (!visited.insert(&spec).second) |
| 184 | + return false; |
| 185 | + |
| 186 | + semantics::DirectComponentIterator directComponents{spec}; |
| 187 | + for (const semantics::Symbol &component : directComponents) { |
| 188 | + if (semantics::IsAllocatableOrPointer(component)) |
| 189 | + return true; |
| 190 | + |
| 191 | + if (const semantics::DeclTypeSpec *declType = component.GetType()) |
| 192 | + if (const auto *nested = declType->AsDerived()) |
| 193 | + if (requiresMapper(*nested)) |
| 194 | + return true; |
| 195 | + } |
| 196 | + return false; |
| 197 | + }; |
| 198 | + |
| 199 | + return requiresMapper(typeSpec); |
| 200 | +} |
| 201 | + |
64 | 202 | int64_t getCollapseValue(const List<Clause> &clauses) { |
65 | 203 | auto iter = llvm::find_if(clauses, [](const Clause &clause) { |
66 | 204 | return clause.id == llvm::omp::Clause::OMPC_collapse; |
@@ -537,6 +675,12 @@ void insertChildMapInfoIntoParent( |
537 | 675 | mapOperands[std::distance(mapSyms.begin(), parentIter)] |
538 | 676 | .getDefiningOp()); |
539 | 677 |
|
| 678 | + // Once explicit members are attached to a parent map, do not also invoke |
| 679 | + // a declare mapper on it, otherwise the mapper would remap the same |
| 680 | + // components leading to duplicate mappings at runtime. |
| 681 | + if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr()) |
| 682 | + mapOp.setMapperIdAttr(nullptr); |
| 683 | + |
540 | 684 | // NOTE: To maintain appropriate SSA ordering, we move the parent map |
541 | 685 | // which will now have references to its children after the last |
542 | 686 | // of its members to be generated. This is necessary when a user |
|
0 commit comments