Skip to content

Commit ae81104

Browse files
committed
Initial implementation of tiling.
1 parent 8cc22ee commit ae81104

File tree

14 files changed

+458
-86
lines changed

14 files changed

+458
-86
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
8080
void genOpenMPSymbolProperties(AbstractConverter &converter,
8181
const pft::Variable &var);
8282

83-
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
8483
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
8584
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
8685
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
404404
return;
405405

406406
const parser::OmpClauseList *beginClauseList = nullptr;
407+
const parser::OmpClauseList *middleClauseList = nullptr;
407408
const parser::OmpClauseList *endClauseList = nullptr;
408409
common::visit(
409410
common::visitors{
@@ -418,6 +419,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
418419
beginClauseList =
419420
&std::get<parser::OmpClauseList>(beginDirective.t);
420421

422+
// FIXME(JAN): For now we check if there is an inner
423+
// OpenMPLoopConstruct, and extract the size clause from there
424+
const auto &innerOptional = std::get<std::optional<
425+
common::Indirection<parser::OpenMPLoopConstruct>>>(
426+
ompConstruct.t);
427+
if (innerOptional.has_value()) {
428+
const auto &innerLoopDirective = innerOptional.value().value();
429+
const auto &innerBegin =
430+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
431+
const auto &innerDirective =
432+
std::get<parser::OmpLoopDirective>(innerBegin.t);
433+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
434+
middleClauseList =
435+
&std::get<parser::OmpClauseList>(innerBegin.t);
436+
}
437+
}
421438
if (auto &endDirective =
422439
std::get<std::optional<parser::OmpEndLoopDirective>>(
423440
ompConstruct.t)) {
@@ -431,6 +448,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
431448
assert(beginClauseList && "expected begin directive");
432449
clauses.append(makeClauses(*beginClauseList, semaCtx));
433450

451+
if (middleClauseList)
452+
clauses.append(makeClauses(*middleClauseList, semaCtx));
453+
434454
if (endClauseList)
435455
clauses.append(makeClauses(*endClauseList, semaCtx));
436456
};
@@ -910,6 +930,7 @@ static void genLoopVars(
910930
storeOp =
911931
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
912932
}
933+
913934
firOpBuilder.setInsertionPointAfter(storeOp);
914935
}
915936

@@ -1660,6 +1681,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16601681
cp.processCollapse(loc, eval, clauseOps, iv);
16611682

16621683
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1684+
1685+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1686+
for (auto &clause : clauses) {
1687+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1688+
const auto &collapse = std::get<clause::Collapse>(clause.u);
1689+
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
1690+
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
1691+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1692+
const auto &sizes = std::get<clause::Sizes>(clause.u);
1693+
llvm::SmallVector<int64_t> sizeValues;
1694+
for (auto &size : sizes.v) {
1695+
int64_t sizeValue = evaluate::ToInt64(size).value();
1696+
sizeValues.push_back(sizeValue);
1697+
}
1698+
clauseOps.tileSizes = sizeValues;
1699+
}
1700+
}
16631701
}
16641702

16651703
static void genLoopClauses(
@@ -2036,9 +2074,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
20362074
return llvm::SmallVector<const semantics::Symbol *>(iv);
20372075
};
20382076

2039-
auto *nestedEval =
2040-
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
2041-
2077+
uint64_t nestValue = getCollapseValue(item->clauses);
2078+
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
2079+
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
20422080
return genOpWithBody<mlir::omp::LoopNestOp>(
20432081
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
20442082
directive)
@@ -3890,6 +3928,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38903928
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
38913929
List<Clause> clauses = makeClauses(
38923930
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
3931+
3932+
const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t);
3933+
if (innerOptional.has_value()) {
3934+
const auto &innerLoopDirective = innerOptional.value().value();
3935+
const auto &innerBegin =
3936+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
3937+
const auto &innerDirective =
3938+
std::get<parser::OmpLoopDirective>(innerBegin.t);
3939+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3940+
clauses.append(
3941+
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
3942+
}
3943+
}
3944+
38933945
if (auto &endLoopDirective =
38943946
std::get<std::optional<parser::OmpEndLoopDirective>>(
38953947
loopConstruct.t)) {
@@ -4021,18 +4073,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
40214073
lower::genDeclareTargetIntGlobal(converter, var);
40224074
}
40234075

4024-
int64_t
4025-
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
4026-
for (const parser::OmpClause &clause : clauseList.v) {
4027-
if (const auto &collapseClause =
4028-
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
4029-
const auto *expr = semantics::GetExpr(collapseClause->v);
4030-
return evaluate::ToInt64(*expr).value();
4031-
}
4032-
}
4033-
return 1;
4034-
}
4035-
40364076
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
40374077
const lower::pft::Variable &var) {
40384078
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,22 @@ namespace lower {
3838
namespace omp {
3939

4040
int64_t getCollapseValue(const List<Clause> &clauses) {
41-
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
42-
return clause.id == llvm::omp::Clause::OMPC_collapse;
43-
});
44-
if (iter != clauses.end()) {
45-
const auto &collapse = std::get<clause::Collapse>(iter->u);
46-
return evaluate::ToInt64(collapse.v).value();
41+
int64_t collapseValue = 1;
42+
int64_t numTileSizes = 0;
43+
for (auto &clause : clauses) {
44+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
45+
const auto &collapse = std::get<clause::Collapse>(clause.u);
46+
collapseValue = evaluate::ToInt64(collapse.v).value();
47+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
48+
const auto &sizes = std::get<clause::Sizes>(clause.u);
49+
numTileSizes = sizes.v.size();
50+
}
4751
}
48-
return 1;
52+
53+
collapseValue = collapseValue - numTileSizes;
54+
int64_t result =
55+
collapseValue > numTileSizes ? collapseValue : numTileSizes;
56+
return result;
4957
}
5058

5159
void genObjectList(const ObjectList &objects,
@@ -613,6 +621,7 @@ bool collectLoopRelatedInfo(
613621
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
614622
mlir::omp::LoopRelatedClauseOps &result,
615623
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
624+
616625
bool found = false;
617626
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
618627

@@ -628,7 +637,16 @@ bool collectLoopRelatedInfo(
628637
collapseValue = evaluate::ToInt64(clause->v).value();
629638
found = true;
630639
}
640+
std::int64_t sizesLengthValue = 0l;
641+
if (auto *clause =
642+
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
643+
sizesLengthValue = clause->v.size();
644+
found = true;
645+
}
631646

647+
collapseValue = collapseValue - sizesLengthValue;
648+
collapseValue =
649+
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
632650
std::size_t loopVarTypeSize = 0;
633651
do {
634652
lower::pft::Evaluation *doLoop =
@@ -661,7 +679,6 @@ bool collectLoopRelatedInfo(
661679
} while (collapseValue > 0);
662680

663681
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
664-
665682
return found;
666683
}
667684

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Parser/parse-tree.h"
1212
#include "flang/Semantics/semantics.h"
1313

14+
# include <stack>
1415
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1516
// Constructs more structured which provide explicit scopes for later
1617
// structural checks and semantic analysis.
@@ -117,15 +118,17 @@ class CanonicalizationOfOmp {
117118
// in the same iteration
118119
//
119120
// Original:
120-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
121-
// OmpBeginLoopDirective
121+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
122+
// OmpBeginLoopDirective t-> OmpLoopDirective
123+
// [ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct u->
124+
/// OmpBeginLoopDirective t-> OmpLoopDirective t-> Tile v-> OMP_tile]
122125
// ExecutableConstruct -> DoConstruct
126+
// [ExecutableConstruct -> OmpEndLoopDirective]
123127
// ExecutableConstruct -> OmpEndLoopDirective (if available)
124128
//
125129
// After rewriting:
126-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
127-
// OmpBeginLoopDirective
128-
// DoConstruct
130+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
131+
// OmpBeginLoopDirective t -> OmpLoopDirective -> DoConstruct
129132
// OmpEndLoopDirective (if available)
130133
parser::Block::iterator nextIt;
131134
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
@@ -147,20 +150,41 @@ class CanonicalizationOfOmp {
147150
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
148151
continue;
149152

153+
// Keep track of the loops to handle the end loop directives
154+
std::stack<parser::OpenMPLoopConstruct *> loops;
155+
loops.push(&x);
156+
while (auto *innerConstruct{
157+
GetConstructIf<parser::OpenMPConstruct>(*nextIt)}) {
158+
if (auto *innerOmpLoop{
159+
std::get_if<parser::OpenMPLoopConstruct>(&innerConstruct->u)}) {
160+
std::get<
161+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
162+
loops.top()->t) = std::move(*innerOmpLoop);
163+
// Retrieveing the address so that DoConstruct or inner loop can be
164+
// set later.
165+
loops.push(&(std::get<std::optional<
166+
common::Indirection<parser::OpenMPLoopConstruct>>>(
167+
loops.top()->t)
168+
.value()
169+
.value()));
170+
nextIt = block.erase(nextIt);
171+
}
172+
}
150173
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
151174
if (doCons->GetLoopControl()) {
152175
// move DoConstruct
153176
std::get<std::optional<std::variant<parser::DoConstruct,
154-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
155-
std::move(*doCons);
177+
common::Indirection<parser::OpenMPLoopConstruct>>>>(
178+
loops.top()->t) = std::move(*doCons);
156179
nextIt = block.erase(nextIt);
157180
// try to match OmpEndLoopDirective
158-
if (nextIt != block.end()) {
181+
while (nextIt != block.end() && !loops.empty()) {
159182
if (auto *endDir{
160183
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
161-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
162-
std::move(*endDir);
184+
std::get<std::optional<parser::OmpEndLoopDirective>>(
185+
loops.top()->t) = std::move(*endDir);
163186
nextIt = block.erase(nextIt);
187+
loops.pop();
164188
}
165189
}
166190
} else {

0 commit comments

Comments
 (0)