Skip to content

Commit 3b10b9a

Browse files
authored
[MLIR][OpenMP] Add lowering support for AUTOMAP modifier (#151513)
Add Automap modifier to the MLIR op definition for the DeclareTarget directive's Enter clause. Also add lowering support in Flang. Automap Ref: OpenMP 6.0 section 7.9.7.
1 parent 13cd725 commit 3b10b9a

20 files changed

+195
-160
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct Variable;
5757
struct OMPDeferredDeclareTargetInfo {
5858
mlir::omp::DeclareTargetCaptureClause declareTargetCaptureClause;
5959
mlir::omp::DeclareTargetDeviceType declareTargetDeviceType;
60+
bool automap = false;
6061
const Fortran::semantics::Symbol &sym;
6162
};
6263

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,12 +1179,13 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
11791179
}
11801180

11811181
bool ClauseProcessor::processLink(
1182-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1182+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
11831183
return findRepeatableClause<omp::clause::Link>(
11841184
[&](const omp::clause::Link &clause, const parser::CharBlock &) {
11851185
// Case: declare target link(var1, var2)...
11861186
gatherFuncAndVarSyms(
1187-
clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
1187+
clause.v, mlir::omp::DeclareTargetCaptureClause::link, result,
1188+
/*automap=*/false);
11881189
});
11891190
}
11901191

@@ -1507,26 +1508,28 @@ bool ClauseProcessor::processTaskReduction(
15071508
}
15081509

15091510
bool ClauseProcessor::processTo(
1510-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1511+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
15111512
return findRepeatableClause<omp::clause::To>(
15121513
[&](const omp::clause::To &clause, const parser::CharBlock &) {
15131514
// Case: declare target to(func, var1, var2)...
15141515
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
1515-
mlir::omp::DeclareTargetCaptureClause::to, result);
1516+
mlir::omp::DeclareTargetCaptureClause::to, result,
1517+
/*automap=*/false);
15161518
});
15171519
}
15181520

15191521
bool ClauseProcessor::processEnter(
1520-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1522+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
15211523
return findRepeatableClause<omp::clause::Enter>(
15221524
[&](const omp::clause::Enter &clause, const parser::CharBlock &source) {
15231525
mlir::Location currentLocation = converter.genLocation(source);
1524-
if (std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t))
1525-
TODO(currentLocation, "Declare target enter AUTOMAP modifier");
1526+
bool automap =
1527+
std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t)
1528+
.has_value();
15261529
// Case: declare target enter(func, var1, var2)...
15271530
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
15281531
mlir::omp::DeclareTargetCaptureClause::enter,
1529-
result);
1532+
result, automap);
15301533
});
15311534
}
15321535

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class ClauseProcessor {
118118
bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
119119
mlir::omp::DependClauseOps &result) const;
120120
bool
121-
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
121+
processEnter(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
122122
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
123123
mlir::omp::IfClauseOps &result) const;
124124
bool processInReduction(
@@ -129,7 +129,7 @@ class ClauseProcessor {
129129
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
130130
bool processLinear(mlir::omp::LinearClauseOps &result) const;
131131
bool
132-
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
132+
processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
133133

134134
// This method is used to process a map clause.
135135
// The optional parameter mapSyms is used to store the original Fortran symbol
@@ -150,7 +150,7 @@ class ClauseProcessor {
150150
bool processTaskReduction(
151151
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
152152
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
153-
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
153+
bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
154154
bool processUseDeviceAddr(
155155
lower::StatementContext &stmtCtx,
156156
mlir::omp::UseDeviceAddrClauseOps &result,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -765,14 +765,14 @@ static void getDeclareTargetInfo(
765765
lower::pft::Evaluation &eval,
766766
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
767767
mlir::omp::DeclareTargetOperands &clauseOps,
768-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
768+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause) {
769769
const auto &spec =
770770
std::get<parser::OmpDeclareTargetSpecifier>(declareTargetConstruct.t);
771771
if (const auto *objectList{parser::Unwrap<parser::OmpObjectList>(spec.u)}) {
772772
ObjectList objects{makeObjects(*objectList, semaCtx)};
773773
// Case: declare target(func, var1, var2)
774774
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
775-
symbolAndClause);
775+
symbolAndClause, /*automap=*/false);
776776
} else if (const auto *clauseList{
777777
parser::Unwrap<parser::OmpClauseList>(spec.u)}) {
778778
List<Clause> clauses = makeClauses(*clauseList, semaCtx);
@@ -805,21 +805,20 @@ static void collectDeferredDeclareTargets(
805805
llvm::SmallVectorImpl<lower::OMPDeferredDeclareTargetInfo>
806806
&deferredDeclareTarget) {
807807
mlir::omp::DeclareTargetOperands clauseOps;
808-
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
808+
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
809809
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
810810
clauseOps, symbolAndClause);
811811
// Return the device type only if at least one of the targets for the
812812
// directive is a function or subroutine
813813
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
814814

815-
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
816-
mlir::Operation *op = mod.lookupSymbol(
817-
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
815+
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
816+
mlir::Operation *op =
817+
mod.lookupSymbol(converter.mangleName(symClause.symbol));
818818

819819
if (!op) {
820-
deferredDeclareTarget.push_back({std::get<0>(symClause),
821-
clauseOps.deviceType,
822-
std::get<1>(symClause)});
820+
deferredDeclareTarget.push_back({symClause.clause, clauseOps.deviceType,
821+
symClause.automap, symClause.symbol});
823822
}
824823
}
825824
}
@@ -830,16 +829,16 @@ getDeclareTargetFunctionDevice(
830829
lower::pft::Evaluation &eval,
831830
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
832831
mlir::omp::DeclareTargetOperands clauseOps;
833-
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
832+
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
834833
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
835834
clauseOps, symbolAndClause);
836835

837836
// Return the device type only if at least one of the targets for the
838837
// directive is a function or subroutine
839838
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
840-
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
841-
mlir::Operation *op = mod.lookupSymbol(
842-
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
839+
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
840+
mlir::Operation *op =
841+
mod.lookupSymbol(converter.mangleName(symClause.symbol));
843842

844843
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
845844
return clauseOps.deviceType;
@@ -1056,7 +1055,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder,
10561055
static void
10571056
markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
10581057
mlir::omp::DeclareTargetCaptureClause captureClause,
1059-
mlir::omp::DeclareTargetDeviceType deviceType) {
1058+
mlir::omp::DeclareTargetDeviceType deviceType, bool automap) {
10601059
// TODO: Add support for program local variables with declare target applied
10611060
auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
10621061
if (!declareTargetOp)
@@ -1071,11 +1070,11 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
10711070
if (declareTargetOp.isDeclareTarget()) {
10721071
if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
10731072
declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
1074-
captureClause);
1073+
captureClause, automap);
10751074
return;
10761075
}
10771076

1078-
declareTargetOp.setDeclareTarget(deviceType, captureClause);
1077+
declareTargetOp.setDeclareTarget(deviceType, captureClause, automap);
10791078
}
10801079

10811080
//===----------------------------------------------------------------------===//
@@ -3564,25 +3563,23 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35643563
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
35653564
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
35663565
mlir::omp::DeclareTargetOperands clauseOps;
3567-
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
3566+
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
35683567
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
35693568
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
35703569
clauseOps, symbolAndClause);
35713570

3572-
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
3573-
mlir::Operation *op = mod.lookupSymbol(
3574-
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
3571+
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
3572+
mlir::Operation *op =
3573+
mod.lookupSymbol(converter.mangleName(symClause.symbol));
35753574

35763575
// Some symbols are deferred until later in the module, these are handled
35773576
// upon finalization of the module for OpenMP inside of Bridge, so we simply
35783577
// skip for now.
35793578
if (!op)
35803579
continue;
35813580

3582-
markDeclareTarget(
3583-
op, converter,
3584-
std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
3585-
clauseOps.deviceType);
3581+
markDeclareTarget(op, converter, symClause.clause, clauseOps.deviceType,
3582+
symClause.automap);
35863583
}
35873584
}
35883585

@@ -4176,7 +4173,7 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
41764173
deviceCodeFound = true;
41774174

41784175
markDeclareTarget(op, converter, declTar.declareTargetCaptureClause,
4179-
devType);
4176+
devType, declTar.automap);
41804177
}
41814178

41824179
return deviceCodeFound;

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,10 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval) {
102102

103103
void gatherFuncAndVarSyms(
104104
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
105-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
105+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
106+
bool automap) {
106107
for (const Object &object : objects)
107-
symbolAndClause.emplace_back(clause, *object.sym());
108+
symbolAndClause.emplace_back(clause, *object.sym(), automap);
108109
}
109110

110111
mlir::omp::MapInfoOp

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,15 @@ class AbstractConverter;
4242

4343
namespace omp {
4444

45-
using DeclareTargetCapturePair =
46-
std::pair<mlir::omp::DeclareTargetCaptureClause, const semantics::Symbol &>;
45+
struct DeclareTargetCaptureInfo {
46+
mlir::omp::DeclareTargetCaptureClause clause;
47+
bool automap = false;
48+
const semantics::Symbol &symbol;
49+
50+
DeclareTargetCaptureInfo(mlir::omp::DeclareTargetCaptureClause c,
51+
const semantics::Symbol &s, bool a = false)
52+
: clause(c), automap(a), symbol(s) {}
53+
};
4754

4855
// A small helper structure for keeping track of a component members MapInfoOp
4956
// and index data when lowering OpenMP map clauses. Keeps track of the
@@ -150,7 +157,8 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval);
150157

151158
void gatherFuncAndVarSyms(
152159
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
153-
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
160+
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
161+
bool automap = false);
154162

155163
int64_t getCollapseValue(const List<Clause> &clauses);
156164

flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ class FunctionFilteringPass
9595
return WalkResult::skip();
9696
}
9797
if (declareTargetOp)
98-
declareTargetOp.setDeclareTarget(declareType,
99-
omp::DeclareTargetCaptureClause::to);
98+
declareTargetOp.setDeclareTarget(
99+
declareType, omp::DeclareTargetCaptureClause::to,
100+
declareTargetOp.getDeclareTargetAutomap());
100101
}
101102
return WalkResult::advance();
102103
});

flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class MarkDeclareTargetPass
3333

3434
void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
3535
mlir::omp::DeclareTargetCaptureClause parentCapClause,
36-
mlir::Operation *currOp,
36+
bool parentAutomap, mlir::Operation *currOp,
3737
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
3838
if (visited.contains(currOp))
3939
return;
@@ -57,13 +57,16 @@ class MarkDeclareTargetPass
5757
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
5858
current.setDeclareTarget(
5959
mlir::omp::DeclareTargetDeviceType::any,
60-
current.getDeclareTargetCaptureClause());
60+
current.getDeclareTargetCaptureClause(),
61+
current.getDeclareTargetAutomap());
6162
}
6263
} else {
63-
current.setDeclareTarget(parentDevTy, parentCapClause);
64+
current.setDeclareTarget(parentDevTy, parentCapClause,
65+
parentAutomap);
6466
}
6567

66-
markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited);
68+
markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
69+
currFOp, visited);
6770
}
6871
}
6972
}
@@ -81,7 +84,8 @@ class MarkDeclareTargetPass
8184
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
8285
markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
8386
declareTargetOp.getDeclareTargetCaptureClause(),
84-
functionOp, visited);
87+
declareTargetOp.getDeclareTargetAutomap(), functionOp,
88+
visited);
8589
}
8690
}
8791

@@ -92,9 +96,10 @@ class MarkDeclareTargetPass
9296
// the contents of the device clause
9397
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
9498
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
95-
markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost,
96-
mlir::omp::DeclareTargetCaptureClause::to, tarOp,
97-
visited);
99+
markNestedFuncs(
100+
/*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
101+
/*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
102+
/*parentAutomap=*/false, tarOp, visited);
98103
});
99104
}
100105
};

flang/test/Lower/OpenMP/common-block-map.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
22

33
!CHECK: fir.global common @var_common_(dense<0> : vector<8xi8>) {{.*}} : !fir.array<8xi8>
4-
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : !fir.array<8xi8>
4+
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link), automap = false>} : !fir.array<8xi8>
55

66
!CHECK-LABEL: func.func @_QPmap_full_block
77
!CHECK: %[[CB_ADDR:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>

0 commit comments

Comments
 (0)