Skip to content

Commit 4e98772

Browse files
committed
[TSAR, Transform] Add handling of nested pragmas.
1 parent 169a53f commit 4e98772

File tree

2 files changed

+103
-171
lines changed

2 files changed

+103
-171
lines changed

include/tsar/Transform/Clang/LoopSwapping.h

Lines changed: 0 additions & 91 deletions
This file was deleted.

lib/Transform/Clang/LoopSwapping.cpp

Lines changed: 103 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,17 @@
3232
#include "tsar/Core/Query.h"
3333
#include "tsar/Frontend/Clang/Pragma.h"
3434
#include "tsar/Support/Clang/Diagnostic.h"
35+
#include "tsar/Support/Clang/Utils.h"
3536
#include "tsar/Support/Clang/SourceLocationTraverse.h"
3637
#include "tsar/Support/GlobalOptions.h"
3738
#include "tsar/Support/PassAAProvider.h"
3839
#include "tsar/Support/Tags.h"
3940
#include "tsar/Transform/Clang/Passes.h"
4041
#include <clang/AST/RecursiveASTVisitor.h>
42+
#include <llvm/ADT/DenseMap.h>
4143
#include <llvm/Analysis/LoopInfo.h>
4244
#include <vector>
45+
#include <stack>
4346

4447
using namespace llvm;
4548
using namespace clang;
@@ -53,16 +56,9 @@ namespace {
5356
/// This provides access to function-level analysis results on server.
5457
using ClangLoopSwappingProvider =
5558
FunctionPassAAProvider<DIEstimateMemoryPass, DIDependencyAnalysisPass>;
56-
57-
using RangeVector = SmallVector<SourceRange, 2>;
58-
59-
using RangePairVector = std::vector<RangeVector>;
60-
61-
using LoopVector = SmallVector<Loop *, 2>;
62-
63-
using LoopPairVector = std::vector<LoopVector>;
64-
6559
using DIAliasTraitVector = std::vector<const DIAliasTrait *>;
60+
using LoopRangeInfo = std::pair<Loop *, SourceRange>;
61+
using LoopRangeVector = SmallVector<LoopRangeInfo, 2>;
6662

6763
class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
6864
private:
@@ -76,26 +72,17 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
7672
, mLangOpts(Rewr.getLangOpts())
7773
, mLoopInfo(LM)
7874
, mState(TraverseState::NONE)
75+
, mCurrentLevel(-1)
7976
{}
8077

81-
void EnterInScope() {
82-
mForLocations.clear();
83-
mForIRs.clear();
84-
}
85-
86-
void ExitFromScope() {
87-
mForLocations.clear();
88-
mForIRs.clear();
89-
}
90-
9178
bool TraverseStmt(Stmt *S) {
9279
if (!S)
9380
return true;
9481
Pragma P(*S);
9582
if (P) {
9683
// Search for loop swapping clause and disable renaming in other pragmas.
9784
if (findClause(P, ClauseId::SwapLoops, mClauses)) {
98-
llvm::SmallVector<clang::CharSourceRange, 8> ToRemove;
85+
SmallVector<CharSourceRange, 8> ToRemove;
9986
auto IsPossible =
10087
pragmaRangeToRemove(P, mClauses, mSrcMgr, mLangOpts, mImportInfo,
10188
ToRemove);
@@ -115,8 +102,12 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
115102
RemoveEmptyLine.RemoveLineIfEmpty = false;
116103
for (auto SR : ToRemove)
117104
mRewriter.RemoveText(SR, RemoveEmptyLine);
105+
mCurrentLevel++;
106+
if (mCurrentLevel + 1 > int(mPragmaLevels.size())) {
107+
mPragmaLevels.resize(mCurrentLevel + 1);
108+
}
109+
mPragmaLevels[mCurrentLevel].push_back(S);
118110
mState = TraverseState::PRAGMA;
119-
EnterInScope();
120111
}
121112
return true;
122113
}
@@ -127,13 +118,12 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
127118
bool TraverseCompoundStmt(CompoundStmt *S) {
128119
if (mState == TraverseState::PRAGMA) {
129120
mState = TraverseState::OUTERFOR;
130-
mForLocations.clear();
131-
mForIRs.clear();
121+
mLoopStack.push(LoopRangeVector());
132122
auto Res = RecursiveASTVisitor::TraverseCompoundStmt(S);
133123
mState = TraverseState::PRAGMA;
134-
mRangePairs.push_back(mForLocations);
135-
mLoopPairs.push_back(mForIRs);
136-
ExitFromScope();
124+
mPragmaLoopsInfo[mPragmaLevels[mCurrentLevel].back()] = mLoopStack.top();
125+
mLoopStack.pop();
126+
mCurrentLevel--;
137127
return Res;
138128
}
139129
return RecursiveASTVisitor::TraverseCompoundStmt(S);
@@ -144,9 +134,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
144134
auto Match = mLoopInfo.find<AST>(S);
145135
if (Match != mLoopInfo.end()) {
146136
Loop *MatchLoop = Match->get<IR>();
147-
mForIRs.push_back(MatchLoop);
148137
SourceRange Range(S->getBeginLoc(), S->getEndLoc());
149-
mForLocations.push_back(Range);
138+
mLoopStack.top().push_back(std::make_pair(MatchLoop, Range));
150139
}
151140
mState = TraverseState::INNERFOR;
152141
auto Res = RecursiveASTVisitor::TraverseForStmt(S);
@@ -156,32 +145,42 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
156145
return RecursiveASTVisitor::TraverseForStmt(S);
157146
}
158147

159-
RangePairVector getRangePairs() const {
160-
return mRangePairs;
148+
const DenseMap<Stmt *, LoopRangeVector> &getPragmaLoopsInfo() const {
149+
return mPragmaLoopsInfo;
161150
}
162151

163-
LoopPairVector getLoopPairs() const {
164-
return mLoopPairs;
152+
const std::vector<SmallVector<Stmt *, 1>> &getPragmaLevels() const {
153+
return mPragmaLevels;
165154
}
166155

167-
size_t getLoopCount() const {
168-
return mRangePairs.size();
156+
size_t getMaxPragmaDepth() const {
157+
return mPragmaLevels.size();
169158
}
170159

171160
void printLocations() const {
172-
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: 'for' loop locations:\n");
173-
size_t LoopNumber = 0;
174-
for (auto locs: mRangePairs) {
175-
for (auto location : locs) {
176-
LLVM_DEBUG(dbgs() << "Loop #" << LoopNumber << ":\n");
177-
SourceLocation Begin = location.getBegin();
178-
SourceLocation End = location.getEnd();
179-
LLVM_DEBUG(dbgs() << "\tBegin: ");
180-
Begin.print(dbgs(), mSrcMgr);
181-
LLVM_DEBUG(dbgs() << "\n\tEnd: ");
182-
End.print(dbgs(), mSrcMgr);
183-
LLVM_DEBUG(dbgs() << '\n');
184-
LoopNumber++;
161+
for (size_t Level = 0; Level < mPragmaLevels.size(); Level++) {
162+
LLVM_DEBUG(dbgs() << "Level " << Level << " :\n");
163+
const auto &CurrentLevel = mPragmaLevels[Level];
164+
for (size_t PragmaN = 0; PragmaN < CurrentLevel.size(); PragmaN++) {
165+
const auto &Pragma = CurrentLevel[PragmaN];
166+
auto PragmaItr = mPragmaLoopsInfo.find(Pragma);
167+
assert(PragmaItr != mPragmaLoopsInfo.end() &&
168+
"Map should contain all pragmas (as keys) from level vectors.");
169+
const auto &Loops = PragmaItr->second;
170+
LLVM_DEBUG(dbgs() << "\tPragma " << PragmaN << " (" << Pragma <<"):\n");
171+
for (const auto &Info : Loops) {
172+
const auto LoopPtr = Info.first;
173+
const auto &Range = Info.second;
174+
SourceLocation Begin = Range.getBegin();
175+
SourceLocation End = Range.getEnd();
176+
LLVM_DEBUG(dbgs() << "\t\t[Range]\n");
177+
LLVM_DEBUG(dbgs() << "\t\tBegin:" << Begin.printToString(mSrcMgr)
178+
<< "\n");
179+
LLVM_DEBUG(dbgs() << "\t\tEnd:" << End.printToString(mSrcMgr) <<"\n");
180+
LLVM_DEBUG(dbgs() << "\t\t\n\t\t[Loop]\n");
181+
const auto &LoopText = mRewriter.getRewrittenText(Range);
182+
LLVM_DEBUG(dbgs() << "\t\t" << LoopText << "\n\n");
183+
}
185184
}
186185
}
187186
}
@@ -194,10 +193,11 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
194193
const LoopMatcherPass::LoopMatcher &mLoopInfo;
195194
TraverseState mState;
196195
SmallVector<Stmt *, 1> mClauses;
197-
RangeVector mForLocations;
198-
LoopVector mForIRs;
199-
RangePairVector mRangePairs;
200-
LoopPairVector mLoopPairs;
196+
197+
int mCurrentLevel;
198+
std::stack<LoopRangeVector> mLoopStack;
199+
std::vector<SmallVector<Stmt *, 1>> mPragmaLevels;
200+
DenseMap<Stmt *, LoopRangeVector> mPragmaLoopsInfo;
201201
};
202202

203203
class ClangLoopSwapping : public FunctionPass, private bcl::Uncopyable {
@@ -213,11 +213,12 @@ class ClangLoopSwapping : public FunctionPass, private bcl::Uncopyable {
213213
private:
214214
void swapLoops(const LoopVisitor &Visitor);
215215
DIAliasTraitVector getLoopTraits(MDNode *LoopID) const;
216-
bool isSwappingAvailable(const LoopVector &Loops) const;
216+
bool isSwappingAvailable(std::pair<Loop *, Loop *> &Loops) const;
217217
bool hasSameReductionKind(const DIAliasTraitVector &TV0,
218218
const DIAliasTraitVector &TV1) const;
219219
bool hasTrueOrAntiDependence(const DIAliasTraitVector &TV0,
220220
const DIAliasTraitVector &TV1) const;
221+
void removePragmas(const std::vector<SmallVector<Stmt *, 1>> &PragmaLevels);
221222

222223
Function *mFunction = nullptr;
223224
TransformationContext *mTfmCtx = nullptr;
@@ -323,17 +324,18 @@ bool ClangLoopSwapping::hasTrueOrAntiDependence(
323324
return false;
324325
}
325326

326-
bool ClangLoopSwapping::isSwappingAvailable(const LoopVector &Loops) const {
327+
bool ClangLoopSwapping::isSwappingAvailable(
328+
std::pair<Loop *, Loop *> &Loops) const {
327329
auto HasLoopID = [this](MDNode*& LoopID, int LoopN) {
328330
if (!LoopID || !(LoopID = mGetLoopID(LoopID))) {
329331
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: ignore loop without ID (loop " <<
330-
LoopN << ").");
332+
LoopN << ").\n");
331333
return false;
332334
}
333335
return true;
334336
};
335-
auto *LoopID0 = Loops[0]->getLoopID();
336-
auto *LoopID1 = Loops[1]->getLoopID();
337+
auto *LoopID0 = Loops.first->getLoopID();
338+
auto *LoopID1 = Loops.second->getLoopID();
337339
if (!HasLoopID(LoopID0, 0))
338340
return false;
339341
if (!HasLoopID(LoopID1, 1))
@@ -353,28 +355,49 @@ bool ClangLoopSwapping::isSwappingAvailable(const LoopVector &Loops) const {
353355
}
354356

355357
void ClangLoopSwapping::swapLoops(const LoopVisitor &Visitor) {
356-
const auto &RangePairs = Visitor.getRangePairs();
357-
const auto &LoopPairs = Visitor.getLoopPairs();
358-
Rewriter &mRewriter = mTfmCtx->getRewriter();
359-
for (size_t i = 0; i < RangePairs.size(); i++) {
360-
const auto &Ranges = RangePairs[i];
361-
const auto &Loops = LoopPairs[i];
362-
if (Ranges.size() < 2) {
363-
toDiag(mSrcMgr->getDiagnostics(),
364-
diag::warn_loop_swapping_missing_loop);
365-
continue;
366-
}
367-
if (Ranges.size() > 2) {
368-
toDiag(mSrcMgr->getDiagnostics(),
369-
diag::warn_loop_swapping_redundant_loop);
370-
}
371-
if (isSwappingAvailable(Loops)) {
372-
const auto &FirstRange = Ranges[0];
373-
const auto &SecondRange = Ranges[1];
374-
const auto &FirstLoop = mRewriter.getRewrittenText(FirstRange);
375-
const auto &SecondLoop = mRewriter.getRewrittenText(SecondRange);
376-
mRewriter.ReplaceText(FirstRange, SecondLoop);
377-
mRewriter.ReplaceText(SecondRange, FirstLoop);
358+
Rewriter &Rewr = mTfmCtx->getRewriter();
359+
auto GetLoopEnd = [this, Rewr](const SourceRange &LoopRange)->SourceLocation {
360+
Token SemiTok;
361+
return (!getRawTokenAfter(LoopRange.getEnd(), *mSrcMgr,
362+
Rewr.getLangOpts(), SemiTok) && SemiTok.is(tok::semi)) ?
363+
SemiTok.getLocation() : LoopRange.getEnd();
364+
};
365+
auto &PragmaLevels = Visitor.getPragmaLevels();
366+
auto &PragmaLoopsInfo = Visitor.getPragmaLoopsInfo();
367+
for (auto it = PragmaLevels.rbegin(); it != PragmaLevels.rend(); it++) {
368+
for (auto &Pragma : *it) {
369+
auto PragmaItr = PragmaLoopsInfo.find(Pragma);
370+
assert(PragmaItr != PragmaLoopsInfo.end() &&
371+
"Map should contain all pragmas (as keys) from level vectors.");
372+
const auto &Loops = PragmaItr->second;
373+
if (Loops.size() < 2) {
374+
toDiag(mSrcMgr->getDiagnostics(),
375+
diag::warn_loop_swapping_missing_loop);
376+
continue;
377+
}
378+
if (Loops.size() > 2) {
379+
toDiag(mSrcMgr->getDiagnostics(),
380+
diag::warn_loop_swapping_redundant_loop);
381+
}
382+
auto Info0 = Loops[0];
383+
auto Info1 = Loops[1];
384+
auto Loop0 = Info0.first;
385+
auto Loop1 = Info1.first;
386+
auto LoopPair = std::make_pair(Loop0, Loop1);
387+
if (isSwappingAvailable(LoopPair)) {
388+
auto &Range0 = Info0.second;
389+
auto &Range1 = Info1.second;
390+
Range0.setEnd(GetLoopEnd(Range0));
391+
Range1.setEnd(GetLoopEnd(Range1));
392+
auto Range0End = Range0.getEnd();
393+
auto Range1Begin = Range1.getBegin();
394+
const auto &LoopText0 = Rewr.getRewrittenText(Range0);
395+
const auto &LoopText1 = Rewr.getRewrittenText(Range1);
396+
Rewr.RemoveText(Range0);
397+
Rewr.RemoveText(Range1);
398+
Rewr.InsertTextBefore(Range0End, LoopText1);
399+
Rewr.InsertTextAfter(Range1Begin, LoopText0);
400+
}
378401
}
379402
}
380403
}
@@ -413,8 +436,8 @@ bool ClangLoopSwapping::runOnFunction(Function &F) {
413436
LoopVisitor Visitor(mTfmCtx->getRewriter(), mLoopInfo, *ImportInfo);
414437
mSrcMgr = &mTfmCtx->getRewriter().getSourceMgr();
415438
Visitor.TraverseDecl(FuncDecl);
416-
if (Visitor.getLoopCount() == 0) {
417-
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: no loop found.\n");
439+
if (Visitor.getMaxPragmaDepth() == 0) {
440+
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: no pragma found.\n");
418441
return false;
419442
}
420443
Visitor.printLocations();

0 commit comments

Comments
 (0)