Skip to content

Commit 32c5784

Browse files
authored
[Flang][OpenMP] Adjust implicit map scalar capture to align with explicit firstprivate (llvm#2986)
2 parents cdc6737 + a340895 commit 32c5784

File tree

14 files changed

+355
-70
lines changed

14 files changed

+355
-70
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,42 @@ bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
4343
[](const auto &functionParserNode) { return false; }});
4444
}
4545

46+
/// Returns true when \p eval holds an OpenMP construct whose directive is
/// `target` itself or a combined/composite directive whose outermost leaf is
/// `target`. Returns false for every other directive and for evaluations that
/// are not OpenMP constructs.
static bool isConstructWithTopLevelTarget(lower::pft::Evaluation &eval) {
  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
  if (!ompEval)
    return false;

  switch (extractOmpDirective(*ompEval)) {
  case llvm::omp::Directive::OMPD_target:
  case llvm::omp::Directive::OMPD_target_loop:
  // `target parallel` and `target teams` are combined constructs with a
  // top-level target just like the loop-associated forms below; without them
  // the nested parallel/teams leaf would not be recognised as belonging to a
  // target construct.
  case llvm::omp::Directive::OMPD_target_parallel:
  case llvm::omp::Directive::OMPD_target_parallel_do:
  case llvm::omp::Directive::OMPD_target_parallel_do_simd:
  case llvm::omp::Directive::OMPD_target_parallel_loop:
  case llvm::omp::Directive::OMPD_target_simd:
  case llvm::omp::Directive::OMPD_target_teams:
  case llvm::omp::Directive::OMPD_target_teams_distribute:
  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
  case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
  case llvm::omp::Directive::OMPD_target_teams_loop:
    return true;
  default:
    return false;
  }
}
71+
4672
DataSharingProcessor::DataSharingProcessor(
4773
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
4874
const List<Clause> &clauses, lower::pft::Evaluation &eval,
4975
bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization,
50-
lower::SymMap &symTable)
76+
lower::SymMap &symTable, bool isTargetPrivitization)
5177
: converter(converter), semaCtx(semaCtx),
5278
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
5379
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
5480
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable),
55-
visitor(semaCtx) {
81+
isTargetPrivitization(isTargetPrivitization), visitor(semaCtx) {
5682
eval.visit([&](const auto &functionParserNode) {
5783
parser::Walk(functionParserNode, visitor);
5884
});
@@ -62,17 +88,18 @@ DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter,
6288
semantics::SemanticsContext &semaCtx,
6389
lower::pft::Evaluation &eval,
6490
bool useDelayedPrivatization,
65-
lower::SymMap &symTable)
91+
lower::SymMap &symTable,
92+
bool isTargetPrivitization)
6693
: DataSharingProcessor(converter, semaCtx, {}, eval,
6794
/*shouldCollectPreDeterminedSymbols=*/false,
68-
useDelayedPrivatization, symTable) {}
95+
useDelayedPrivatization, symTable,
96+
isTargetPrivitization) {}
6997

7098
// First privatization step: invokes the four collect* helpers one after the
// other. Going by the helper names, they gather (in this order) the symbols
// selected for privatization, the default ones, the implicit ones and the
// pre-determined ones; the helpers themselves are defined elsewhere.
void DataSharingProcessor::processStep1() {
7199
collectSymbolsForPrivatization();
72100
collectDefaultSymbols();
73101
collectImplicitSymbols();
74102
collectPreDeterminedSymbols();
75-
76103
}
77104

78105
void DataSharingProcessor::processStep2(
@@ -558,8 +585,19 @@ void DataSharingProcessor::collectSymbols(
558585
};
559586

560587
auto shouldCollectSymbol = [&](const semantics::Symbol *sym) {
561-
if (collectImplicit)
588+
if (collectImplicit) {
589+
// If we're a combined construct with a target region, implicit
590+
// firstprivate captures should only belong to the target region
591+
// and not be added/captured by later directives. Parallel regions
592+
// will likely want the same captures to be shared and for SIMD it's
593+
// illegal to have firstprivate clauses.
594+
if (isConstructWithTopLevelTarget(eval) && !isTargetPrivitization &&
595+
sym->test(semantics::Symbol::Flag::OmpFirstPrivate)) {
596+
return false;
597+
}
598+
562599
return sym->test(semantics::Symbol::Flag::OmpImplicit);
600+
}
563601

564602
if (collectPreDetermined)
565603
return sym->test(semantics::Symbol::Flag::OmpPreDetermined);

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class DataSharingProcessor {
9393
bool useDelayedPrivatization;
9494
llvm::SmallSet<const semantics::Symbol *, 16> mightHaveReadHostSym;
9595
lower::SymMap &symTable;
96+
bool isTargetPrivitization;
9697
OMPConstructSymbolVisitor visitor;
9798
bool privatizationDone = false;
9899

@@ -131,12 +132,14 @@ class DataSharingProcessor {
131132
const List<Clause> &clauses,
132133
lower::pft::Evaluation &eval,
133134
bool shouldCollectPreDeterminedSymbols,
134-
bool useDelayedPrivatization, lower::SymMap &symTable);
135+
bool useDelayedPrivatization, lower::SymMap &symTable,
136+
bool isTargetPrivitization = false);
135137

136138
DataSharingProcessor(lower::AbstractConverter &converter,
137139
semantics::SemanticsContext &semaCtx,
138140
lower::pft::Evaluation &eval,
139-
bool useDelayedPrivatization, lower::SymMap &symTable);
141+
bool useDelayedPrivatization, lower::SymMap &symTable,
142+
bool isTargetPrivitization = false);
140143

141144
// Privatisation is split into 3 steps:
142145
//

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,6 +2376,36 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23762376
queue, item, clauseOps);
23772377
}
23782378

2379+
static bool isDuplicateMappedSymbol(
2380+
const semantics::Symbol &sym,
2381+
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
2382+
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2383+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2384+
llvm::SmallVector<const semantics::Symbol *> concatSyms;
2385+
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2386+
mappedSyms.size());
2387+
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
2388+
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
2389+
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2390+
2391+
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
2392+
if (llvm::is_contained(concatSyms, &checkSym))
2393+
return true;
2394+
2395+
return std::any_of(concatSyms.begin(), concatSyms.end(),
2396+
[&](auto v) { return v->GetUltimate() == checkSym; });
2397+
};
2398+
2399+
if (checkSymbol(sym))
2400+
return true;
2401+
2402+
const auto *hostAssoc{sym.detailsIf<semantics::HostAssocDetails>()};
2403+
if (hostAssoc && checkSymbol(hostAssoc->symbol()))
2404+
return true;
2405+
2406+
return checkSymbol(sym.GetUltimate());
2407+
}
2408+
23792409
static mlir::omp::TargetOp
23802410
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23812411
lower::StatementContext &stmtCtx,
@@ -2402,7 +2432,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24022432
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
24032433
/*shouldCollectPreDeterminedSymbols=*/
24042434
lower::omp::isLastItemInQueue(item, queue),
2405-
/*useDelayedPrivatization=*/true, symTable);
2435+
/*useDelayedPrivatization=*/true, symTable,
2436+
/*isTargetPrivitization=*/true);
24062437
dsp.processStep1();
24072438
dsp.processStep2(&clauseOps);
24082439

@@ -2412,17 +2443,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24122443
// attribute clauses (neither data-sharing; e.g. `private`, nor `map`
24132444
// clauses).
24142445
auto captureImplicitMap = [&](const semantics::Symbol &sym) {
2415-
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
2416-
return;
2417-
2418-
// Skip parameters/constants as they do not need to be mapped.
2419-
if (semantics::IsNamedConstant(sym))
2420-
return;
2421-
2422-
// These symbols are mapped individually in processHasDeviceAddr.
2423-
if (llvm::is_contained(hasDeviceAddrSyms, &sym))
2424-
return;
2425-
24262446
// Structure component symbols don't have bindings, and can only be
24272447
// explicitly mapped individually. If a member is captured implicitly
24282448
// we map the entirety of the derived type when we find its symbol.
@@ -2443,7 +2463,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24432463
if (!converter.getSymbolAddress(sym))
24442464
return;
24452465

2446-
if (!llvm::is_contained(mapSyms, &sym)) {
2466+
// Skip parameters/constants as they do not need to be mapped.
2467+
if (semantics::IsNamedConstant(sym))
2468+
return;
2469+
2470+
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2471+
hasDeviceAddrSyms, mapSyms)) {
24472472
if (const auto *details =
24482473
sym.template detailsIf<semantics::HostAssocDetails>())
24492474
converter.copySymbolBinding(details->symbol(), sym);

flang/lib/Lower/Support/Utils.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212

1313
#include "flang/Lower/Support/Utils.h"
1414

15+
#include "flang/Common/idioms.h"
1516
#include "flang/Common/indirection.h"
1617
#include "flang/Lower/AbstractConverter.h"
1718
#include "flang/Lower/ConvertVariable.h"
1819
#include "flang/Lower/IterationSpace.h"
20+
#include "flang/Lower/OpenMP/Utils.h"
1921
#include "flang/Lower/Support/PrivateReductionUtils.h"
2022
#include "flang/Optimizer/Builder/HLFIRTools.h"
2123
#include "flang/Optimizer/Builder/Todo.h"
@@ -26,6 +28,90 @@
2628
#include <optional>
2729
#include <type_traits>
2830

31+
/// Get the directive enumeration value corresponding to the given OpenMP
32+
/// construct PFT node.
33+
llvm::omp::Directive
34+
extractOmpDirective(const Fortran::parser::OpenMPConstruct &ompConstruct) {
35+
return Fortran::common::visit(
36+
Fortran::common::visitors{
37+
[](const Fortran::parser::OpenMPAllocatorsConstruct &c) {
38+
return llvm::omp::OMPD_allocators;
39+
},
40+
[](const Fortran::parser::OpenMPAssumeConstruct &c) {
41+
return llvm::omp::OMPD_assume;
42+
},
43+
[](const Fortran::parser::OpenMPAtomicConstruct &c) {
44+
return llvm::omp::OMPD_atomic;
45+
},
46+
[](const Fortran::parser::OpenMPBlockConstruct &c) {
47+
return std::get<Fortran::parser::OmpBlockDirective>(
48+
std::get<Fortran::parser::OmpBeginBlockDirective>(c.t).t)
49+
.v;
50+
},
51+
[](const Fortran::parser::OpenMPCriticalConstruct &c) {
52+
return llvm::omp::OMPD_critical;
53+
},
54+
[](const Fortran::parser::OpenMPDeclarativeAllocate &c) {
55+
return llvm::omp::OMPD_allocate;
56+
},
57+
[](const Fortran::parser::OpenMPDispatchConstruct &c) {
58+
return llvm::omp::OMPD_dispatch;
59+
},
60+
[](const Fortran::parser::OpenMPExecutableAllocate &c) {
61+
return llvm::omp::OMPD_allocate;
62+
},
63+
[](const Fortran::parser::OpenMPLoopConstruct &c) {
64+
return std::get<Fortran::parser::OmpLoopDirective>(
65+
std::get<Fortran::parser::OmpBeginLoopDirective>(c.t).t)
66+
.v;
67+
},
68+
[](const Fortran::parser::OpenMPSectionConstruct &c) {
69+
return llvm::omp::OMPD_section;
70+
},
71+
[](const Fortran::parser::OpenMPSectionsConstruct &c) {
72+
return std::get<Fortran::parser::OmpSectionsDirective>(
73+
std::get<Fortran::parser::OmpBeginSectionsDirective>(c.t)
74+
.t)
75+
.v;
76+
},
77+
[](const Fortran::parser::OpenMPStandaloneConstruct &c) {
78+
return Fortran::common::visit(
79+
Fortran::common::visitors{
80+
[](const Fortran::parser::OpenMPSimpleStandaloneConstruct
81+
&c) { return c.v.DirId(); },
82+
[](const Fortran::parser::OpenMPFlushConstruct &c) {
83+
return llvm::omp::OMPD_flush;
84+
},
85+
[](const Fortran::parser::OpenMPCancelConstruct &c) {
86+
return llvm::omp::OMPD_cancel;
87+
},
88+
[](const Fortran::parser::OpenMPCancellationPointConstruct
89+
&c) { return llvm::omp::OMPD_cancellation_point; },
90+
[](const Fortran::parser::OmpMetadirectiveDirective &c) {
91+
return llvm::omp::OMPD_metadirective;
92+
},
93+
[](const Fortran::parser::OpenMPDepobjConstruct &c) {
94+
return llvm::omp::OMPD_depobj;
95+
},
96+
[](const Fortran::parser::OpenMPInteropConstruct &c) {
97+
return llvm::omp::OMPD_interop;
98+
}},
99+
c.u);
100+
},
101+
[](const Fortran::parser::OpenMPUtilityConstruct &c) {
102+
return Fortran::common::visit(
103+
Fortran::common::visitors{
104+
[](const Fortran::parser::OmpErrorDirective &c) {
105+
return llvm::omp::OMPD_error;
106+
},
107+
[](const Fortran::parser::OmpNothingDirective &c) {
108+
return llvm::omp::OMPD_nothing;
109+
}},
110+
c.u);
111+
}},
112+
ompConstruct.u);
113+
}
114+
29115
namespace Fortran::lower {
30116
// Fortran::evaluate::Expr are functional values organized like an AST. A
31117
// Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ class MapsForPrivatizedSymbolsPass
5555

5656
omp::MapInfoOp createMapInfo(Location loc, Value var,
5757
fir::FirOpBuilder &builder) {
58+
// Check if a value of type `type` can be passed to the kernel by value.
59+
// All kernel parameters are of pointer type, so if the value can be
60+
// represented inside of a pointer, then it can be passed by value.
61+
auto isLiteralType = [&](mlir::Type type) {
62+
const mlir::DataLayout &dl = builder.getDataLayout();
63+
mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
64+
uint64_t ptrSize = dl.getTypeSize(ptrTy);
65+
uint64_t ptrAlign = dl.getTypePreferredAlignment(ptrTy);
66+
67+
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
68+
loc, type, dl, builder.getKindMap());
69+
return size <= ptrSize && align <= ptrAlign;
70+
};
71+
5872
uint64_t mapTypeTo = static_cast<
5973
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
6074
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -94,14 +108,22 @@ class MapsForPrivatizedSymbolsPass
94108
if (needsBoundsOps(varPtr))
95109
genBoundsOps(builder, varPtr, boundsOps);
96110

111+
mlir::omp::VariableCaptureKind captureKind =
112+
mlir::omp::VariableCaptureKind::ByRef;
113+
if (fir::isa_trivial(fir::unwrapRefType(varPtr.getType())) ||
114+
fir::isa_char(fir::unwrapRefType(varPtr.getType()))) {
115+
if (isLiteralType(fir::unwrapRefType(varPtr.getType()))) {
116+
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
117+
}
118+
}
119+
97120
return builder.create<omp::MapInfoOp>(
98121
loc, varPtr.getType(), varPtr,
99122
TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
100123
.getElementType()),
101124
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
102125
mapTypeTo),
103-
builder.getAttr<omp::VariableCaptureKindAttr>(
104-
omp::VariableCaptureKind::ByRef),
126+
builder.getAttr<omp::VariableCaptureKindAttr>(captureKind),
105127
/*varPtrPtr=*/Value{},
106128
/*members=*/SmallVector<Value>{},
107129
/*member_index=*/mlir::ArrayAttr{},

0 commit comments

Comments
 (0)