Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
llvm::ArrayRef<mlir::Value> lenParams,
bool asTarget = false);

/// Create a two dimensional ArrayAttr containing integer data as
/// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
mlir::ArrayAttr create2DI64ArrayAttr(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);

/// Create a temporary using `fir.alloca`. This function does not hoist.
/// It is the callers responsibility to set the insertion point if
/// hoisting is required.
Expand Down
76 changes: 44 additions & 32 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,16 +889,17 @@ void ClauseProcessor::processMapObjects(
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
std::optional<omp::Object> parentObj;

lower::AddrAndBoundsInfo info =
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
Expand All @@ -907,28 +908,46 @@ void ClauseProcessor::processMapObjects(
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

mlir::Value baseOp = info.rawInput;
if (object.sym()->owner().IsDerivedType()) {
omp::ObjectList objectList = gatherObjects(object, semaCtx);
assert(!objectList.empty() &&
"could not find parent objects of derived type member");
parentObj = objectList[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: It's not clear to me why the first returned object is guaranteed to be the parent. Perhaps adding a description to gatherObjects would help.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, consider the symbol: "parent%child%member", we pass the whole symbol, and start from the RHS and peel back the symbol layers until we get to the end, "parent" in this case, we then reverse the list, so "parent" is the first member and it goes from LHS to RHS rather than RHS to LHS! Mainly mentioning to help explain a bit, more than happy to add a comment for it though :-)

parentMemberIndices.emplace(parentObj.value(),
OmpMapParentAndMemberData{});

if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
llvm::SmallVector<int64_t> indices;
generateMemberPlacementIndices(object, indices, semaCtx);
baseOp = createParentSymAndGenIntermediateMaps(
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
parentMemberIndices[parentObj.value()], asFortran.str(),
mapTypeBits);
}
}

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value baseOp = info.rawInput;
auto location = mlir::NameLoc::get(
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
baseOp.getLoc());
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, location, baseOp,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

if (object.sym()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
if (parentObj.has_value()) {
addChildIndexAndMapToParent(
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
if (mapSyms)
mapSyms->push_back(object.sym());
mapSyms->push_back(object.sym());
if (mapSymTypes)
mapSymTypes->push_back(baseOp.getType());
if (mapSymLocs)
Expand All @@ -949,9 +968,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
Expand Down Expand Up @@ -1003,17 +1020,15 @@ bool ClauseProcessor::processMap(
mapSymLocs, mapSymTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
*ptrMapSyms, mapSymTypes, mapSymLocs);

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
return clauseFound;
}

bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
llvm::SmallVector<const semantics::Symbol *> mapSymbols;

auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
Expand All @@ -1034,9 +1049,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
clauseFound =
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymbols,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
insertChildMapInfoIntoParent(
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
return clauseFound;
}

Expand Down Expand Up @@ -1110,9 +1125,7 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const parser::CharBlock &source) {
Expand All @@ -1125,9 +1138,9 @@ bool ClauseProcessor::processUseDeviceAddr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDeviceAddrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand All @@ -1136,9 +1149,8 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const parser::CharBlock &source) {
Expand All @@ -1151,9 +1163,9 @@ bool ClauseProcessor::processUseDevicePtr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDevicePtrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ class ClauseProcessor {
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
Expand Down
11 changes: 11 additions & 0 deletions flang/lib/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ struct IdTyTemplate {
return designator == other.designator;
}

// Defining an "ordering" which allows types derived from this to be
// utilised in maps and other containers that require comparison
// operators for ordering
bool operator<(const IdTyTemplate &other) const {
return symbol < other.symbol;
}

operator bool() const { return symbol != nullptr; }
};

Expand All @@ -76,6 +83,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
Fortran::semantics::Symbol *sym() const { return identity.symbol; }
const std::optional<ExprTy> &ref() const { return identity.designator; }

bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
return identity < other.identity;
}

IdTy identity;
};
} // namespace tomp::type
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ static void genBodyOfTargetOp(
firOpBuilder, copyVal.getLoc(), copyVal,
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
/*members=*/llvm::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
/*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
Expand Down Expand Up @@ -1792,7 +1792,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
name.str(), bounds, /*members=*/{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
/*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapFlag),
Expand Down
Loading
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.