Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 87 additions & 26 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
#include "flang/Utils/OpenMP.h"
Expand Down Expand Up @@ -1223,26 +1224,66 @@ void ClauseProcessor::processMapObjects(
llvm::StringRef mapperIdNameRef) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

auto getDefaultMapperID = [&](const omp::Object &object,
std::string &mapperIdName) {
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
firOpBuilder.getRegion().getParentOp())) {
const semantics::DerivedTypeSpec *typeSpec = nullptr;
auto getSymbolDerivedType = [](const semantics::Symbol &symbol)
-> const semantics::DerivedTypeSpec * {
const semantics::Symbol &ultimate = symbol.GetUltimate();
if (const semantics::DeclTypeSpec *declType = ultimate.GetType())
if (const auto *derived = declType->AsDerived())
return derived;
return nullptr;
};

if (object.sym()->owner().IsDerivedType())
typeSpec = object.sym()->owner().derivedTypeSpec();
else if (object.sym()->GetType() &&
object.sym()->GetType()->category() ==
semantics::DeclTypeSpec::TypeDerived)
typeSpec = &object.sym()->GetType()->derivedTypeSpec();

if (typeSpec) {
mapperIdName =
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
}
auto addImplicitMapper = [&](const omp::Object &object,
std::string &mapperIdName,
bool allowGenerate) -> mlir::FlatSymbolRefAttr {
if (mapperIdName.empty())
return mlir::FlatSymbolRefAttr();

if (converter.getModuleOp().lookupSymbol(mapperIdName))
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);

if (!allowGenerate)
return mlir::FlatSymbolRefAttr();

const semantics::DerivedTypeSpec *typeSpec =
getSymbolDerivedType(*object.sym());
if (!typeSpec && object.sym()->owner().IsDerivedType())
typeSpec = object.sym()->owner().derivedTypeSpec();

if (!typeSpec)
return mlir::FlatSymbolRefAttr();

mlir::Type type = converter.genType(*typeSpec);
auto recordType = mlir::dyn_cast<fir::RecordType>(type);
if (!recordType)
return mlir::FlatSymbolRefAttr();

return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
recordType, mapperIdName);
};

auto getDefaultMapperID =
[&](const semantics::DerivedTypeSpec *typeSpec) -> std::string {
if (mlir::isa<mlir::omp::DeclareMapperOp>(
firOpBuilder.getRegion().getParentOp()) ||
!typeSpec)
return {};

std::string mapperIdName =
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) {
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
} else {
mapperIdName = converter.mangleName(mapperIdName, *typeSpec->GetScope());
}

// Make sure we don't return a mapper to self.
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
firOpBuilder.getRegion().getParentOp()))
if (mapperIdName == declMapOp.getSymName())
return {};
return mapperIdName;
};

// Create the mapper symbol from its name, if specified.
Expand All @@ -1251,8 +1292,13 @@ void ClauseProcessor::processMapObjects(
mapperIdNameRef != "__implicit_mapper") {
std::string mapperIdName = mapperIdNameRef.str();
const omp::Object &object = objects.front();
if (mapperIdNameRef == "default")
getDefaultMapperID(object, mapperIdName);
if (mapperIdNameRef == "default") {
const semantics::DerivedTypeSpec *typeSpec =
getSymbolDerivedType(*object.sym());
if (!typeSpec && object.sym()->owner().IsDerivedType())
typeSpec = object.sym()->owner().derivedTypeSpec();
mapperIdName = getDefaultMapperID(typeSpec);
}
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
"mapper not found");
mapperId =
Expand Down Expand Up @@ -1290,13 +1336,28 @@ void ClauseProcessor::processMapObjects(
}
}

const semantics::DerivedTypeSpec *objectTypeSpec =
getSymbolDerivedType(*object.sym());

if (mapperIdNameRef == "__implicit_mapper") {
std::string mapperIdName;
getDefaultMapperID(object, mapperIdName);
mapperId = converter.getModuleOp().lookupSymbol(mapperIdName)
? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName)
: mlir::FlatSymbolRefAttr();
if (parentObj.has_value()) {
mapperId = mlir::FlatSymbolRefAttr();
} else if (objectTypeSpec) {
std::string mapperIdName = getDefaultMapperID(objectTypeSpec);
bool needsDefaultMapper =
semantics::IsAllocatableOrObjectPointer(object.sym()) ||
(objectTypeSpec &&
requiresImplicitDefaultDeclareMapper(*objectTypeSpec));
bool containsDelete = (mapTypeBits & mlir::omp::ClauseMapFlags::del) !=
mlir::omp::ClauseMapFlags::none;
if (!mapperIdName.empty() && !containsDelete)
mapperId = addImplicitMapper(object, mapperIdName,
/*allowGenerate=*/needsDefaultMapper);
else
mapperId = mlir::FlatSymbolRefAttr();
} else {
mapperId = mlir::FlatSymbolRefAttr();
}
}

// Explicit map captures are captured ByRef by default,
Expand Down
47 changes: 35 additions & 12 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2578,18 +2578,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();

mlir::FlatSymbolRefAttr mapperId;
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
auto &typeSpec = sym.GetType()->derivedTypeSpec();
std::string mapperIdName =
typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
if (converter.getModuleOp().lookupSymbol(mapperIdName))
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);
}

fir::factory::AddrAndBoundsInfo info =
Fortran::lower::getDataOperandBaseAddr(
converter, firOpBuilder, sym.GetUltimate(),
Expand All @@ -2609,6 +2597,41 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mapFlagAndKind = getImplicitMapTypeAndKind(
firOpBuilder, converter, defaultMaps, eleType, loc, sym);

mlir::FlatSymbolRefAttr mapperId;
if (defaultMaps.empty()) {
const semantics::DerivedTypeSpec *typeSpec =
sym.GetType() ? sym.GetType()->AsDerived() : nullptr;
if (typeSpec) {
std::string mapperIdName =
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
if (auto *mapperSym =
converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName =
converter.mangleName(mapperIdName, mapperSym->owner());
else
mapperIdName =
converter.mangleName(mapperIdName, *typeSpec->GetScope());

if (!mapperIdName.empty()) {
bool allowImplicitMapper =
semantics::IsAllocatableOrObjectPointer(&sym);
bool hasDefaultMapper =
converter.getModuleOp().lookupSymbol(mapperIdName);
if (hasDefaultMapper || allowImplicitMapper) {
if (!hasDefaultMapper) {
if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
converter.genType(*typeSpec)))
mapperId = getOrGenImplicitDefaultDeclareMapper(
converter, loc, recordType, mapperIdName);
} else {
mapperId = mlir::FlatSymbolRefAttr::get(
&converter.getMLIRContext(), mapperIdName);
}
}
}
}
}

mlir::Value mapOp = createMapInfoOp(
firOpBuilder, converter.getCurrentLocation(), baseOp,
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{},
Expand Down
150 changes: 150 additions & 0 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,28 @@

#include "ClauseFinder.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/tools.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/OpenMP/Clauses.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Lower/Support/PrivateReductionUtils.h>
#include <flang/Optimizer/Builder/BoxValue.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
#include <flang/Optimizer/HLFIR/HLFIROps.h>
#include <flang/Parser/openmp-utils.h>
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
#include <flang/Semantics/type.h>
#include <flang/Utils/OpenMP.h>
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/CommandLine.h>

#include <functional>
#include <iterator>

template <typename T>
Expand Down Expand Up @@ -61,6 +67,144 @@ namespace Fortran {
namespace lower {
namespace omp {

mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
lower::AbstractConverter &converter, mlir::Location loc,
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
if (mapperNameStr.empty())
return {};

if (converter.getModuleOp().lookupSymbol(mapperNameStr))
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperNameStr);

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);

firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
firOpBuilder, loc, mapperNameStr, recordType);
auto &region = declMapperOp.getRegion();
firOpBuilder.createBlock(&region);
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);

auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
/*uniq_name=*/"");

const auto genBoundsOps = [&](mlir::Value mapVal,
llvm::SmallVectorImpl<mlir::Value> &bounds) {
fir::ExtendedValue extVal =
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
hlfir::Entity{mapVal},
/*contiguousHint=*/true)
.first;
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
firOpBuilder, info, extVal,
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
};

const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
mlir::Type fieldTy, mlir::Type recType) {
mlir::Value field = fir::FieldIndexOp::create(
firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName,
recType, fir::getTypeParams(rec));
return fir::CoordinateOp::create(
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
};

llvm::SmallVector<mlir::Value> clauseMapVars;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
llvm::SmallVector<mlir::Value> memberMapOps;

mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to;
mapFlag |= mlir::omp::ClauseMapFlags::from;
mapFlag |= mlir::omp::ClauseMapFlags::implicit;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;

for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
const auto &memberName = entry.value().first;
const auto &memberType = entry.value().second;
mlir::FlatSymbolRefAttr mapperId;
if (auto recType = mlir::dyn_cast<fir::RecordType>(
fir::getFortranElementType(memberType))) {
std::string mapperIdName =
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
else if (auto *memberSym =
converter.getCurrentScope().FindSymbol(memberName))
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());

mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
mapperIdName);
}

auto ref =
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(ref, bounds);
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
bounds,
/*members=*/{},
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
/*partialMap=*/false, mapperId);
memberMapOps.emplace_back(mapOp);
memberPlacementIndices.emplace_back(
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
}

llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(declareOp.getOriginalBase(), bounds);
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::to;
parentMapFlag |= mlir::omp::ClauseMapFlags::from;
parentMapFlag |= mlir::omp::ClauseMapFlags::implicit;
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
firOpBuilder, loc, declareOp.getOriginalBase(),
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
captureKind, declareOp.getType(0),
/*partialMap=*/true);

clauseMapVars.emplace_back(mapOp);
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperNameStr);
}

bool requiresImplicitDefaultDeclareMapper(
const semantics::DerivedTypeSpec &typeSpec) {
// ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit
// default mappers available so that OpenMP offloading can correctly map them.
if (semantics::IsIsoCType(&typeSpec))
return true;

llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited;

std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper =
[&](const semantics::DerivedTypeSpec &spec) -> bool {
if (!visited.insert(&spec).second)
return false;

semantics::DirectComponentIterator directComponents{spec};
for (const semantics::Symbol &component : directComponents) {
if (semantics::IsAllocatableOrPointer(component))
return true;

if (const semantics::DeclTypeSpec *declType = component.GetType())
if (const auto *nested = declType->AsDerived())
if (requiresMapper(*nested))
return true;
}
return false;
};

return requiresMapper(typeSpec);
}

int64_t getCollapseValue(const List<Clause> &clauses) {
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_collapse;
Expand Down Expand Up @@ -537,6 +681,12 @@ void insertChildMapInfoIntoParent(
mapOperands[std::distance(mapSyms.begin(), parentIter)]
.getDefiningOp());

// Once explicit members are attached to a parent map, do not also invoke
// a declare mapper on it, otherwise the mapper would remap the same
// components leading to duplicate mappings at runtime.
if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr())
mapOp.setMapperIdAttr(nullptr);

// NOTE: To maintain appropriate SSA ordering, we move the parent map
// which will now have references to its children after the last
// of its members to be generated. This is necessary when a user
Expand Down
Loading
Loading