Skip to content

Commit a340895

Browse files
committed
[Flang][OpenMP][WIP] Adjust implicit map scalar capture to align with explicit firstprivate
Currently the implicit map capture of scalars is a little off from the explicit firstpriviate clause generation, the runtime will indicate it is firstprivate, but the code generation is different enough that it prevents it from truly being treated as firstprivate. So this patch tries to align these a little better aligning the code generation for implicit with explicit firstprivate. NOTE: This shouldn't be merged at the moment, it's for testing, if we're happy with it I'll have to think about how we'll align downstream with upstream without posing to many conflicts with tests. As we'll want this on by default with no extra flags required downstream I imagine, but upstream we have a flag (commented out in this) to enable firstprivate in certain cases (we have it commented out here), that if not enabled will cause semantic errors if firstprivate is utilised, and the implicit scalars will trigger it if left enabled when the flag is disabled, so the implicit behaviour also needs to be sheltered behind the flag for now.
1 parent 3c73b57 commit a340895

File tree

14 files changed

+355
-70
lines changed

14 files changed

+355
-70
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,42 @@ bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
4242
[](const auto &functionParserNode) { return false; }});
4343
}
4444

45+
static bool isConstructWithTopLevelTarget(lower::pft::Evaluation &eval) {
46+
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
47+
if (ompEval) {
48+
auto dir = extractOmpDirective(*ompEval);
49+
switch (dir) {
50+
case llvm::omp::Directive::OMPD_target:
51+
case llvm::omp::Directive::OMPD_target_loop:
52+
case llvm::omp::Directive::OMPD_target_parallel_do:
53+
case llvm::omp::Directive::OMPD_target_parallel_do_simd:
54+
case llvm::omp::Directive::OMPD_target_parallel_loop:
55+
case llvm::omp::Directive::OMPD_target_teams_distribute:
56+
case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
57+
case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
58+
case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
59+
case llvm::omp::Directive::OMPD_target_teams_loop:
60+
case llvm::omp::Directive::OMPD_target_simd:
61+
return true;
62+
break;
63+
default:
64+
return false;
65+
break;
66+
}
67+
}
68+
return false;
69+
}
70+
4571
DataSharingProcessor::DataSharingProcessor(
4672
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
4773
const List<Clause> &clauses, lower::pft::Evaluation &eval,
4874
bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization,
49-
lower::SymMap &symTable)
75+
lower::SymMap &symTable, bool isTargetPrivitization)
5076
: converter(converter), semaCtx(semaCtx),
5177
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
5278
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
5379
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable),
54-
visitor(semaCtx) {
80+
isTargetPrivitization(isTargetPrivitization), visitor(semaCtx) {
5581
eval.visit([&](const auto &functionParserNode) {
5682
parser::Walk(functionParserNode, visitor);
5783
});
@@ -61,17 +87,18 @@ DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter,
6187
semantics::SemanticsContext &semaCtx,
6288
lower::pft::Evaluation &eval,
6389
bool useDelayedPrivatization,
64-
lower::SymMap &symTable)
90+
lower::SymMap &symTable,
91+
bool isTargetPrivitization)
6592
: DataSharingProcessor(converter, semaCtx, {}, eval,
6693
/*shouldCollectPreDeterminedSymols=*/false,
67-
useDelayedPrivatization, symTable) {}
94+
useDelayedPrivatization, symTable,
95+
isTargetPrivitization) {}
6896

6997
void DataSharingProcessor::processStep1() {
7098
collectSymbolsForPrivatization();
7199
collectDefaultSymbols();
72100
collectImplicitSymbols();
73101
collectPreDeterminedSymbols();
74-
75102
}
76103

77104
void DataSharingProcessor::processStep2(
@@ -556,8 +583,19 @@ void DataSharingProcessor::collectSymbols(
556583
};
557584

558585
auto shouldCollectSymbol = [&](const semantics::Symbol *sym) {
559-
if (collectImplicit)
586+
if (collectImplicit) {
587+
// If we're a combined construct with a target region, implicit
588+
// firstprivate captures, should only belong to the target region
589+
// and not be added/captured by later directives. Parallel regions
590+
// will likely want the same captures to be shared and for SIMD it's
591+
// illegal to have firstprivate clauses.
592+
if (isConstructWithTopLevelTarget(eval) && !isTargetPrivitization &&
593+
sym->test(semantics::Symbol::Flag::OmpFirstPrivate)) {
594+
return false;
595+
}
596+
560597
return sym->test(semantics::Symbol::Flag::OmpImplicit);
598+
}
561599

562600
if (collectPreDetermined)
563601
return sym->test(semantics::Symbol::Flag::OmpPreDetermined);

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class DataSharingProcessor {
9393
bool useDelayedPrivatization;
9494
llvm::SmallSet<const semantics::Symbol *, 16> mightHaveReadHostSym;
9595
lower::SymMap &symTable;
96+
bool isTargetPrivitization;
9697
OMPConstructSymbolVisitor visitor;
9798
bool privatizationDone = false;
9899

@@ -131,12 +132,14 @@ class DataSharingProcessor {
131132
const List<Clause> &clauses,
132133
lower::pft::Evaluation &eval,
133134
bool shouldCollectPreDeterminedSymbols,
134-
bool useDelayedPrivatization, lower::SymMap &symTable);
135+
bool useDelayedPrivatization, lower::SymMap &symTable,
136+
bool isTargetPrivitization = false);
135137

136138
DataSharingProcessor(lower::AbstractConverter &converter,
137139
semantics::SemanticsContext &semaCtx,
138140
lower::pft::Evaluation &eval,
139-
bool useDelayedPrivatization, lower::SymMap &symTable);
141+
bool useDelayedPrivatization, lower::SymMap &symTable,
142+
bool isTargetPrivitization = false);
140143

141144
// Privatisation is split into 3 steps:
142145
//

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,6 +2396,36 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23962396
queue, item, clauseOps);
23972397
}
23982398

2399+
static bool isDuplicateMappedSymbol(
2400+
const semantics::Symbol &sym,
2401+
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
2402+
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2403+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2404+
llvm::SmallVector<const semantics::Symbol *> concatSyms;
2405+
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2406+
mappedSyms.size());
2407+
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
2408+
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
2409+
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2410+
2411+
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
2412+
if (llvm::is_contained(concatSyms, &checkSym))
2413+
return true;
2414+
2415+
return std::any_of(concatSyms.begin(), concatSyms.end(),
2416+
[&](auto v) { return v->GetUltimate() == checkSym; });
2417+
};
2418+
2419+
if (checkSymbol(sym))
2420+
return true;
2421+
2422+
const auto *hostAssoc{sym.detailsIf<semantics::HostAssocDetails>()};
2423+
if (hostAssoc && checkSymbol(hostAssoc->symbol()))
2424+
return true;
2425+
2426+
return checkSymbol(sym.GetUltimate());
2427+
}
2428+
23992429
static mlir::omp::TargetOp
24002430
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24012431
lower::StatementContext &stmtCtx,
@@ -2422,7 +2452,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24222452
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
24232453
/*shouldCollectPreDeterminedSymbols=*/
24242454
lower::omp::isLastItemInQueue(item, queue),
2425-
/*useDelayedPrivatization=*/true, symTable);
2455+
/*useDelayedPrivatization=*/true, symTable,
2456+
/*isTargetPrivitization=*/true);
24262457
dsp.processStep1();
24272458
dsp.processStep2(&clauseOps);
24282459

@@ -2432,17 +2463,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24322463
// attribute clauses (neither data-sharing; e.g. `private`, nor `map`
24332464
// clauses).
24342465
auto captureImplicitMap = [&](const semantics::Symbol &sym) {
2435-
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
2436-
return;
2437-
2438-
// Skip parameters/constants as they do not need to be mapped.
2439-
if (semantics::IsNamedConstant(sym))
2440-
return;
2441-
2442-
// These symbols are mapped individually in processHasDeviceAddr.
2443-
if (llvm::is_contained(hasDeviceAddrSyms, &sym))
2444-
return;
2445-
24462466
// Structure component symbols don't have bindings, and can only be
24472467
// explicitly mapped individually. If a member is captured implicitly
24482468
// we map the entirety of the derived type when we find its symbol.
@@ -2463,7 +2483,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24632483
if (!converter.getSymbolAddress(sym))
24642484
return;
24652485

2466-
if (!llvm::is_contained(mapSyms, &sym)) {
2486+
// Skip parameters/constants as they do not need to be mapped.
2487+
if (semantics::IsNamedConstant(sym))
2488+
return;
2489+
2490+
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2491+
hasDeviceAddrSyms, mapSyms)) {
24672492
if (const auto *details =
24682493
sym.template detailsIf<semantics::HostAssocDetails>())
24692494
converter.copySymbolBinding(details->symbol(), sym);

flang/lib/Lower/Support/Utils.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212

1313
#include "flang/Lower/Support/Utils.h"
1414

15+
#include "flang/Common/idioms.h"
1516
#include "flang/Common/indirection.h"
1617
#include "flang/Lower/AbstractConverter.h"
1718
#include "flang/Lower/ConvertVariable.h"
1819
#include "flang/Lower/IterationSpace.h"
20+
#include "flang/Lower/OpenMP/Utils.h"
1921
#include "flang/Lower/Support/PrivateReductionUtils.h"
2022
#include "flang/Optimizer/Builder/HLFIRTools.h"
2123
#include "flang/Optimizer/Builder/Todo.h"
@@ -26,6 +28,90 @@
2628
#include <optional>
2729
#include <type_traits>
2830

31+
/// Get the directive enumeration value corresponding to the given OpenMP
32+
/// construct PFT node.
33+
llvm::omp::Directive
34+
extractOmpDirective(const Fortran::parser::OpenMPConstruct &ompConstruct) {
35+
return Fortran::common::visit(
36+
Fortran::common::visitors{
37+
[](const Fortran::parser::OpenMPAllocatorsConstruct &c) {
38+
return llvm::omp::OMPD_allocators;
39+
},
40+
[](const Fortran::parser::OpenMPAssumeConstruct &c) {
41+
return llvm::omp::OMPD_assume;
42+
},
43+
[](const Fortran::parser::OpenMPAtomicConstruct &c) {
44+
return llvm::omp::OMPD_atomic;
45+
},
46+
[](const Fortran::parser::OpenMPBlockConstruct &c) {
47+
return std::get<Fortran::parser::OmpBlockDirective>(
48+
std::get<Fortran::parser::OmpBeginBlockDirective>(c.t).t)
49+
.v;
50+
},
51+
[](const Fortran::parser::OpenMPCriticalConstruct &c) {
52+
return llvm::omp::OMPD_critical;
53+
},
54+
[](const Fortran::parser::OpenMPDeclarativeAllocate &c) {
55+
return llvm::omp::OMPD_allocate;
56+
},
57+
[](const Fortran::parser::OpenMPDispatchConstruct &c) {
58+
return llvm::omp::OMPD_dispatch;
59+
},
60+
[](const Fortran::parser::OpenMPExecutableAllocate &c) {
61+
return llvm::omp::OMPD_allocate;
62+
},
63+
[](const Fortran::parser::OpenMPLoopConstruct &c) {
64+
return std::get<Fortran::parser::OmpLoopDirective>(
65+
std::get<Fortran::parser::OmpBeginLoopDirective>(c.t).t)
66+
.v;
67+
},
68+
[](const Fortran::parser::OpenMPSectionConstruct &c) {
69+
return llvm::omp::OMPD_section;
70+
},
71+
[](const Fortran::parser::OpenMPSectionsConstruct &c) {
72+
return std::get<Fortran::parser::OmpSectionsDirective>(
73+
std::get<Fortran::parser::OmpBeginSectionsDirective>(c.t)
74+
.t)
75+
.v;
76+
},
77+
[](const Fortran::parser::OpenMPStandaloneConstruct &c) {
78+
return Fortran::common::visit(
79+
Fortran::common::visitors{
80+
[](const Fortran::parser::OpenMPSimpleStandaloneConstruct
81+
&c) { return c.v.DirId(); },
82+
[](const Fortran::parser::OpenMPFlushConstruct &c) {
83+
return llvm::omp::OMPD_flush;
84+
},
85+
[](const Fortran::parser::OpenMPCancelConstruct &c) {
86+
return llvm::omp::OMPD_cancel;
87+
},
88+
[](const Fortran::parser::OpenMPCancellationPointConstruct
89+
&c) { return llvm::omp::OMPD_cancellation_point; },
90+
[](const Fortran::parser::OmpMetadirectiveDirective &c) {
91+
return llvm::omp::OMPD_metadirective;
92+
},
93+
[](const Fortran::parser::OpenMPDepobjConstruct &c) {
94+
return llvm::omp::OMPD_depobj;
95+
},
96+
[](const Fortran::parser::OpenMPInteropConstruct &c) {
97+
return llvm::omp::OMPD_interop;
98+
}},
99+
c.u);
100+
},
101+
[](const Fortran::parser::OpenMPUtilityConstruct &c) {
102+
return Fortran::common::visit(
103+
Fortran::common::visitors{
104+
[](const Fortran::parser::OmpErrorDirective &c) {
105+
return llvm::omp::OMPD_error;
106+
},
107+
[](const Fortran::parser::OmpNothingDirective &c) {
108+
return llvm::omp::OMPD_nothing;
109+
}},
110+
c.u);
111+
}},
112+
ompConstruct.u);
113+
}
114+
29115
namespace Fortran::lower {
30116
// Fortran::evaluate::Expr are functional values organized like an AST. A
31117
// Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ class MapsForPrivatizedSymbolsPass
5555

5656
omp::MapInfoOp createMapInfo(Location loc, Value var,
5757
fir::FirOpBuilder &builder) {
58+
// Check if a value of type `type` can be passed to the kernel by value.
59+
// All kernel parameters are of pointer type, so if the value can be
60+
// represented inside of a pointer, then it can be passed by value.
61+
auto isLiteralType = [&](mlir::Type type) {
62+
const mlir::DataLayout &dl = builder.getDataLayout();
63+
mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
64+
uint64_t ptrSize = dl.getTypeSize(ptrTy);
65+
uint64_t ptrAlign = dl.getTypePreferredAlignment(ptrTy);
66+
67+
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
68+
loc, type, dl, builder.getKindMap());
69+
return size <= ptrSize && align <= ptrAlign;
70+
};
71+
5872
uint64_t mapTypeTo = static_cast<
5973
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
6074
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -94,14 +108,22 @@ class MapsForPrivatizedSymbolsPass
94108
if (needsBoundsOps(varPtr))
95109
genBoundsOps(builder, varPtr, boundsOps);
96110

111+
mlir::omp::VariableCaptureKind captureKind =
112+
mlir::omp::VariableCaptureKind::ByRef;
113+
if (fir::isa_trivial(fir::unwrapRefType(varPtr.getType())) ||
114+
fir::isa_char(fir::unwrapRefType(varPtr.getType()))) {
115+
if (isLiteralType(fir::unwrapRefType(varPtr.getType()))) {
116+
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
117+
}
118+
}
119+
97120
return builder.create<omp::MapInfoOp>(
98121
loc, varPtr.getType(), varPtr,
99122
TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
100123
.getElementType()),
101124
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
102125
mapTypeTo),
103-
builder.getAttr<omp::VariableCaptureKindAttr>(
104-
omp::VariableCaptureKind::ByRef),
126+
builder.getAttr<omp::VariableCaptureKindAttr>(captureKind),
105127
/*varPtrPtr=*/Value{},
106128
/*members=*/SmallVector<Value>{},
107129
/*member_index=*/mlir::ArrayAttr{},

0 commit comments

Comments
 (0)