Skip to content

Commit 77e5be4

Browse files
committed
Initial implementation of tiling.
1 parent ceda56b commit 77e5be4

File tree

14 files changed

+458
-86
lines changed

14 files changed

+458
-86
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
7979
void genOpenMPSymbolProperties(AbstractConverter &converter,
8080
const pft::Variable &var);
8181

82-
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
8382
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
8483
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
8584
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
403403
return;
404404

405405
const parser::OmpClauseList *beginClauseList = nullptr;
406+
const parser::OmpClauseList *middleClauseList = nullptr;
406407
const parser::OmpClauseList *endClauseList = nullptr;
407408
common::visit(
408409
common::visitors{
@@ -417,6 +418,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
417418
beginClauseList =
418419
&std::get<parser::OmpClauseList>(beginDirective.t);
419420

421+
// FIXME(JAN): For now we check if there is an inner
422+
// OpenMPLoopConstruct, and extract the sizes clause from there
423+
const auto &innerOptional = std::get<std::optional<
424+
common::Indirection<parser::OpenMPLoopConstruct>>>(
425+
ompConstruct.t);
426+
if (innerOptional.has_value()) {
427+
const auto &innerLoopDirective = innerOptional.value().value();
428+
const auto &innerBegin =
429+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
430+
const auto &innerDirective =
431+
std::get<parser::OmpLoopDirective>(innerBegin.t);
432+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
433+
middleClauseList =
434+
&std::get<parser::OmpClauseList>(innerBegin.t);
435+
}
436+
}
420437
if (auto &endDirective =
421438
std::get<std::optional<parser::OmpEndLoopDirective>>(
422439
ompConstruct.t)) {
@@ -430,6 +447,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
430447
assert(beginClauseList && "expected begin directive");
431448
clauses.append(makeClauses(*beginClauseList, semaCtx));
432449

450+
if (middleClauseList)
451+
clauses.append(makeClauses(*middleClauseList, semaCtx));
452+
433453
if (endClauseList)
434454
clauses.append(makeClauses(*endClauseList, semaCtx));
435455
};
@@ -910,6 +930,7 @@ static void genLoopVars(
910930
storeOp =
911931
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
912932
}
933+
913934
firOpBuilder.setInsertionPointAfter(storeOp);
914935
}
915936

@@ -1660,6 +1681,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16601681
cp.processCollapse(loc, eval, clauseOps, iv);
16611682

16621683
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1684+
1685+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1686+
for (auto &clause : clauses) {
1687+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1688+
const auto &collapse = std::get<clause::Collapse>(clause.u);
1689+
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
1690+
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
1691+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1692+
const auto &sizes = std::get<clause::Sizes>(clause.u);
1693+
llvm::SmallVector<int64_t> sizeValues;
1694+
for (auto &size : sizes.v) {
1695+
int64_t sizeValue = evaluate::ToInt64(size).value();
1696+
sizeValues.push_back(sizeValue);
1697+
}
1698+
clauseOps.tileSizes = sizeValues;
1699+
}
1700+
}
16631701
}
16641702

16651703
static void genLoopClauses(
@@ -2036,9 +2074,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
20362074
return llvm::SmallVector<const semantics::Symbol *>(iv);
20372075
};
20382076

2039-
auto *nestedEval =
2040-
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
2041-
2077+
uint64_t nestValue = getCollapseValue(item->clauses);
2078+
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
2079+
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
20422080
return genOpWithBody<mlir::omp::LoopNestOp>(
20432081
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
20442082
directive)
@@ -3863,6 +3901,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38633901
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
38643902
List<Clause> clauses = makeClauses(
38653903
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
3904+
3905+
const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t);
3906+
if (innerOptional.has_value()) {
3907+
const auto &innerLoopDirective = innerOptional.value().value();
3908+
const auto &innerBegin =
3909+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
3910+
const auto &innerDirective =
3911+
std::get<parser::OmpLoopDirective>(innerBegin.t);
3912+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3913+
clauses.append(
3914+
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
3915+
}
3916+
}
3917+
38663918
if (auto &endLoopDirective =
38673919
std::get<std::optional<parser::OmpEndLoopDirective>>(
38683920
loopConstruct.t)) {
@@ -3994,18 +4046,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
39944046
lower::genDeclareTargetIntGlobal(converter, var);
39954047
}
39964048

3997-
int64_t
3998-
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
3999-
for (const parser::OmpClause &clause : clauseList.v) {
4000-
if (const auto &collapseClause =
4001-
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
4002-
const auto *expr = semantics::GetExpr(collapseClause->v);
4003-
return evaluate::ToInt64(*expr).value();
4004-
}
4005-
}
4006-
return 1;
4007-
}
4008-
40094049
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
40104050
const lower::pft::Variable &var) {
40114051
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,22 @@ namespace lower {
3838
namespace omp {
3939

4040
int64_t getCollapseValue(const List<Clause> &clauses) {
41-
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
42-
return clause.id == llvm::omp::Clause::OMPC_collapse;
43-
});
44-
if (iter != clauses.end()) {
45-
const auto &collapse = std::get<clause::Collapse>(iter->u);
46-
return evaluate::ToInt64(collapse.v).value();
41+
int64_t collapseValue = 1;
42+
int64_t numTileSizes = 0;
43+
for (auto &clause : clauses) {
44+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
45+
const auto &collapse = std::get<clause::Collapse>(clause.u);
46+
collapseValue = evaluate::ToInt64(collapse.v).value();
47+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
48+
const auto &sizes = std::get<clause::Sizes>(clause.u);
49+
numTileSizes = sizes.v.size();
50+
}
4751
}
48-
return 1;
52+
53+
collapseValue = collapseValue - numTileSizes;
54+
int64_t result =
55+
collapseValue > numTileSizes ? collapseValue : numTileSizes;
56+
return result;
4957
}
5058

5159
void genObjectList(const ObjectList &objects,
@@ -612,6 +620,7 @@ bool collectLoopRelatedInfo(
612620
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
613621
mlir::omp::LoopRelatedClauseOps &result,
614622
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
623+
615624
bool found = false;
616625
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
617626

@@ -627,7 +636,16 @@ bool collectLoopRelatedInfo(
627636
collapseValue = evaluate::ToInt64(clause->v).value();
628637
found = true;
629638
}
639+
std::int64_t sizesLengthValue = 0l;
640+
if (auto *clause =
641+
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
642+
sizesLengthValue = clause->v.size();
643+
found = true;
644+
}
630645

646+
collapseValue = collapseValue - sizesLengthValue;
647+
collapseValue =
648+
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
631649
std::size_t loopVarTypeSize = 0;
632650
do {
633651
lower::pft::Evaluation *doLoop =
@@ -660,7 +678,6 @@ bool collectLoopRelatedInfo(
660678
} while (collapseValue > 0);
661679

662680
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
663-
664681
return found;
665682
}
666683

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Parser/parse-tree.h"
1212
#include "flang/Semantics/semantics.h"
1313

14+
# include <stack>
1415
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1516
// Constructs more structured which provide explicit scopes for later
1617
// structural checks and semantic analysis.
@@ -117,15 +118,17 @@ class CanonicalizationOfOmp {
117118
// in the same iteration
118119
//
119120
// Original:
120-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
121-
// OmpBeginLoopDirective
121+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
122+
// OmpBeginLoopDirective t-> OmpLoopDirective
123+
// [ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct u->
124+
// OmpBeginLoopDirective t-> OmpLoopDirective t-> Tile v-> OMP_tile]
122125
// ExecutableConstruct -> DoConstruct
126+
// [ExecutableConstruct -> OmpEndLoopDirective]
123127
// ExecutableConstruct -> OmpEndLoopDirective (if available)
124128
//
125129
// After rewriting:
126-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
127-
// OmpBeginLoopDirective
128-
// DoConstruct
130+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
131+
// OmpBeginLoopDirective t -> OmpLoopDirective -> DoConstruct
129132
// OmpEndLoopDirective (if available)
130133
parser::Block::iterator nextIt;
131134
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
@@ -147,20 +150,41 @@ class CanonicalizationOfOmp {
147150
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
148151
continue;
149152

153+
// Keep track of the loops to handle the end loop directives
154+
std::stack<parser::OpenMPLoopConstruct *> loops;
155+
loops.push(&x);
156+
while (auto *innerConstruct{
157+
GetConstructIf<parser::OpenMPConstruct>(*nextIt)}) {
158+
if (auto *innerOmpLoop{
159+
std::get_if<parser::OpenMPLoopConstruct>(&innerConstruct->u)}) {
160+
std::get<
161+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
162+
loops.top()->t) = std::move(*innerOmpLoop);
163+
// Retrieving the address so that DoConstruct or inner loop can be
164+
// set later.
165+
loops.push(&(std::get<std::optional<
166+
common::Indirection<parser::OpenMPLoopConstruct>>>(
167+
loops.top()->t)
168+
.value()
169+
.value()));
170+
nextIt = block.erase(nextIt);
171+
}
172+
}
150173
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
151174
if (doCons->GetLoopControl()) {
152175
// move DoConstruct
153176
std::get<std::optional<std::variant<parser::DoConstruct,
154-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
155-
std::move(*doCons);
177+
common::Indirection<parser::OpenMPLoopConstruct>>>>(
178+
loops.top()->t) = std::move(*doCons);
156179
nextIt = block.erase(nextIt);
157180
// try to match OmpEndLoopDirective
158-
if (nextIt != block.end()) {
181+
while (nextIt != block.end() && !loops.empty()) {
159182
if (auto *endDir{
160183
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
161-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
162-
std::move(*endDir);
184+
std::get<std::optional<parser::OmpEndLoopDirective>>(
185+
loops.top()->t) = std::move(*endDir);
163186
nextIt = block.erase(nextIt);
187+
loops.pop();
164188
}
165189
}
166190
} else {

0 commit comments

Comments
 (0)