3232#include " tsar/Core/Query.h"
3333#include " tsar/Frontend/Clang/Pragma.h"
3434#include " tsar/Support/Clang/Diagnostic.h"
35+ #include " tsar/Support/Clang/Utils.h"
3536#include " tsar/Support/Clang/SourceLocationTraverse.h"
3637#include " tsar/Support/GlobalOptions.h"
3738#include " tsar/Support/PassAAProvider.h"
3839#include " tsar/Support/Tags.h"
3940#include " tsar/Transform/Clang/Passes.h"
4041#include < clang/AST/RecursiveASTVisitor.h>
42+ #include < llvm/ADT/DenseMap.h>
4143#include < llvm/Analysis/LoopInfo.h>
4244#include < vector>
45+ #include < stack>
4346
4447using namespace llvm ;
4548using namespace clang ;
@@ -53,16 +56,9 @@ namespace {
5356// / This provides access to function-level analysis results on server.
5457using ClangLoopSwappingProvider =
5558 FunctionPassAAProvider<DIEstimateMemoryPass, DIDependencyAnalysisPass>;
56-
57- using RangeVector = SmallVector<SourceRange, 2 >;
58-
59- using RangePairVector = std::vector<RangeVector>;
60-
61- using LoopVector = SmallVector<Loop *, 2 >;
62-
63- using LoopPairVector = std::vector<LoopVector>;
64-
6559using DIAliasTraitVector = std::vector<const DIAliasTrait *>;
60+ using LoopRangeInfo = std::pair<Loop *, SourceRange>;
61+ using LoopRangeVector = SmallVector<LoopRangeInfo, 2 >;
6662
6763class LoopVisitor : public RecursiveASTVisitor <LoopVisitor> {
6864private:
@@ -76,26 +72,17 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
7672 , mLangOpts (Rewr.getLangOpts())
7773 , mLoopInfo (LM)
7874 , mState (TraverseState::NONE)
75+ , mCurrentLevel (-1 )
7976 {}
8077
81- void EnterInScope () {
82- mForLocations .clear ();
83- mForIRs .clear ();
84- }
85-
86- void ExitFromScope () {
87- mForLocations .clear ();
88- mForIRs .clear ();
89- }
90-
9178 bool TraverseStmt (Stmt *S) {
9279 if (!S)
9380 return true ;
9481 Pragma P (*S);
9582 if (P) {
9683 // Search for loop swapping clause and disable renaming in other pragmas.
9784 if (findClause (P, ClauseId::SwapLoops, mClauses )) {
98- llvm:: SmallVector<clang:: CharSourceRange, 8 > ToRemove;
85+ SmallVector<CharSourceRange, 8 > ToRemove;
9986 auto IsPossible =
10087 pragmaRangeToRemove (P, mClauses , mSrcMgr , mLangOpts , mImportInfo ,
10188 ToRemove);
@@ -115,8 +102,12 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
115102 RemoveEmptyLine.RemoveLineIfEmpty = false ;
116103 for (auto SR : ToRemove)
117104 mRewriter .RemoveText (SR, RemoveEmptyLine);
105+ mCurrentLevel ++;
106+ if (mCurrentLevel + 1 > int (mPragmaLevels .size ())) {
107+ mPragmaLevels .resize (mCurrentLevel + 1 );
108+ }
109+ mPragmaLevels [mCurrentLevel ].push_back (S);
118110 mState = TraverseState::PRAGMA;
119- EnterInScope ();
120111 }
121112 return true ;
122113 }
@@ -127,13 +118,12 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
127118 bool TraverseCompoundStmt (CompoundStmt *S) {
128119 if (mState == TraverseState::PRAGMA) {
129120 mState = TraverseState::OUTERFOR;
130- mForLocations .clear ();
131- mForIRs .clear ();
121+ mLoopStack .push (LoopRangeVector ());
132122 auto Res = RecursiveASTVisitor::TraverseCompoundStmt (S);
133123 mState = TraverseState::PRAGMA;
134- mRangePairs . push_back ( mForLocations );
135- mLoopPairs . push_back ( mForIRs );
136- ExitFromScope () ;
124+ mPragmaLoopsInfo [ mPragmaLevels [ mCurrentLevel ]. back ()] = mLoopStack . top ( );
125+ mLoopStack . pop ( );
126+ mCurrentLevel -- ;
137127 return Res;
138128 }
139129 return RecursiveASTVisitor::TraverseCompoundStmt (S);
@@ -144,9 +134,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
144134 auto Match = mLoopInfo .find <AST>(S);
145135 if (Match != mLoopInfo .end ()) {
146136 Loop *MatchLoop = Match->get <IR>();
147- mForIRs .push_back (MatchLoop);
148137 SourceRange Range (S->getBeginLoc (), S->getEndLoc ());
149- mForLocations . push_back (Range);
138+ mLoopStack . top (). push_back (std::make_pair (MatchLoop, Range) );
150139 }
151140 mState = TraverseState::INNERFOR;
152141 auto Res = RecursiveASTVisitor::TraverseForStmt (S);
@@ -156,32 +145,42 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
156145 return RecursiveASTVisitor::TraverseForStmt (S);
157146 }
158147
159- RangePairVector getRangePairs () const {
160- return mRangePairs ;
148+ const DenseMap<Stmt *, LoopRangeVector> & getPragmaLoopsInfo () const {
149+ return mPragmaLoopsInfo ;
161150 }
162151
163- LoopPairVector getLoopPairs () const {
164- return mLoopPairs ;
152+ const std::vector<SmallVector<Stmt *, 1 >> & getPragmaLevels () const {
153+ return mPragmaLevels ;
165154 }
166155
167- size_t getLoopCount () const {
168- return mRangePairs .size ();
156+ size_t getMaxPragmaDepth () const {
157+ return mPragmaLevels .size ();
169158 }
170159
171160 void printLocations () const {
172- LLVM_DEBUG (dbgs () << " [LOOP SWAPPING]: 'for' loop locations:\n " );
173- size_t LoopNumber = 0 ;
174- for (auto locs: mRangePairs ) {
175- for (auto location : locs) {
176- LLVM_DEBUG (dbgs () << " Loop #" << LoopNumber << " :\n " );
177- SourceLocation Begin = location.getBegin ();
178- SourceLocation End = location.getEnd ();
179- LLVM_DEBUG (dbgs () << " \t Begin: " );
180- Begin.print (dbgs (), mSrcMgr );
181- LLVM_DEBUG (dbgs () << " \n\t End: " );
182- End.print (dbgs (), mSrcMgr );
183- LLVM_DEBUG (dbgs () << ' \n ' );
184- LoopNumber++;
161+ for (size_t Level = 0 ; Level < mPragmaLevels .size (); Level++) {
162+ LLVM_DEBUG (dbgs () << " Level " << Level << " :\n " );
163+ const auto &CurrentLevel = mPragmaLevels [Level];
164+ for (size_t PragmaN = 0 ; PragmaN < CurrentLevel.size (); PragmaN++) {
165+ const auto &Pragma = CurrentLevel[PragmaN];
166+ auto PragmaItr = mPragmaLoopsInfo .find (Pragma);
167+ assert (PragmaItr != mPragmaLoopsInfo .end () &&
168+ " Map should contain all pragmas (as keys) from level vectors." );
169+ const auto &Loops = PragmaItr->second ;
170+ LLVM_DEBUG (dbgs () << " \t Pragma " << PragmaN << " (" << Pragma <<" ):\n " );
171+ for (const auto &Info : Loops) {
172+ const auto LoopPtr = Info.first ;
173+ const auto &Range = Info.second ;
174+ SourceLocation Begin = Range.getBegin ();
175+ SourceLocation End = Range.getEnd ();
176+ LLVM_DEBUG (dbgs () << " \t\t [Range]\n " );
177+ LLVM_DEBUG (dbgs () << " \t\t Begin:" << Begin.printToString (mSrcMgr )
178+ << " \n " );
179+ LLVM_DEBUG (dbgs () << " \t\t End:" << End.printToString (mSrcMgr ) <<" \n " );
180+ LLVM_DEBUG (dbgs () << " \t\t\n\t\t [Loop]\n " );
181+ const auto &LoopText = mRewriter .getRewrittenText (Range);
182+ LLVM_DEBUG (dbgs () << " \t\t " << LoopText << " \n\n " );
183+ }
185184 }
186185 }
187186 }
@@ -194,10 +193,11 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
194193 const LoopMatcherPass::LoopMatcher &mLoopInfo ;
195194 TraverseState mState ;
196195 SmallVector<Stmt *, 1 > mClauses ;
197- RangeVector mForLocations ;
198- LoopVector mForIRs ;
199- RangePairVector mRangePairs ;
200- LoopPairVector mLoopPairs ;
196+
197+ int mCurrentLevel ;
198+ std::stack<LoopRangeVector> mLoopStack ;
199+ std::vector<SmallVector<Stmt *, 1 >> mPragmaLevels ;
200+ DenseMap<Stmt *, LoopRangeVector> mPragmaLoopsInfo ;
201201};
202202
203203class ClangLoopSwapping : public FunctionPass , private bcl ::Uncopyable {
@@ -213,11 +213,12 @@ class ClangLoopSwapping : public FunctionPass, private bcl::Uncopyable {
213213private:
214214 void swapLoops (const LoopVisitor &Visitor);
215215 DIAliasTraitVector getLoopTraits (MDNode *LoopID) const ;
216- bool isSwappingAvailable (const LoopVector &Loops) const ;
216+ bool isSwappingAvailable (std::pair<Loop *, Loop *> &Loops) const ;
217217 bool hasSameReductionKind (const DIAliasTraitVector &TV0,
218218 const DIAliasTraitVector &TV1) const ;
219219 bool hasTrueOrAntiDependence (const DIAliasTraitVector &TV0,
220220 const DIAliasTraitVector &TV1) const ;
221+ void removePragmas (const std::vector<SmallVector<Stmt *, 1 >> &PragmaLevels);
221222
222223 Function *mFunction = nullptr ;
223224 TransformationContext *mTfmCtx = nullptr ;
@@ -323,17 +324,18 @@ bool ClangLoopSwapping::hasTrueOrAntiDependence(
323324 return false ;
324325}
325326
326- bool ClangLoopSwapping::isSwappingAvailable (const LoopVector &Loops) const {
327+ bool ClangLoopSwapping::isSwappingAvailable (
328+ std::pair<Loop *, Loop *> &Loops) const {
327329 auto HasLoopID = [this ](MDNode*& LoopID, int LoopN) {
328330 if (!LoopID || !(LoopID = mGetLoopID (LoopID))) {
329331 LLVM_DEBUG (dbgs () << " [LOOP SWAPPING]: ignore loop without ID (loop " <<
330- LoopN << " )." );
332+ LoopN << " ).\n " );
331333 return false ;
332334 }
333335 return true ;
334336 };
335- auto *LoopID0 = Loops[ 0 ] ->getLoopID ();
336- auto *LoopID1 = Loops[ 1 ] ->getLoopID ();
337+ auto *LoopID0 = Loops. first ->getLoopID ();
338+ auto *LoopID1 = Loops. second ->getLoopID ();
337339 if (!HasLoopID (LoopID0, 0 ))
338340 return false ;
339341 if (!HasLoopID (LoopID1, 1 ))
@@ -353,28 +355,49 @@ bool ClangLoopSwapping::isSwappingAvailable(const LoopVector &Loops) const {
353355}
354356
355357void ClangLoopSwapping::swapLoops (const LoopVisitor &Visitor) {
356- const auto &RangePairs = Visitor.getRangePairs ();
357- const auto &LoopPairs = Visitor.getLoopPairs ();
358- Rewriter &mRewriter = mTfmCtx ->getRewriter ();
359- for (size_t i = 0 ; i < RangePairs.size (); i++) {
360- const auto &Ranges = RangePairs[i];
361- const auto &Loops = LoopPairs[i];
362- if (Ranges.size () < 2 ) {
363- toDiag (mSrcMgr ->getDiagnostics (),
364- diag::warn_loop_swapping_missing_loop);
365- continue ;
366- }
367- if (Ranges.size () > 2 ) {
368- toDiag (mSrcMgr ->getDiagnostics (),
369- diag::warn_loop_swapping_redundant_loop);
370- }
371- if (isSwappingAvailable (Loops)) {
372- const auto &FirstRange = Ranges[0 ];
373- const auto &SecondRange = Ranges[1 ];
374- const auto &FirstLoop = mRewriter .getRewrittenText (FirstRange);
375- const auto &SecondLoop = mRewriter .getRewrittenText (SecondRange);
376- mRewriter .ReplaceText (FirstRange, SecondLoop);
377- mRewriter .ReplaceText (SecondRange, FirstLoop);
358+ Rewriter &Rewr = mTfmCtx ->getRewriter ();
359+ auto GetLoopEnd = [this , Rewr](const SourceRange &LoopRange)->SourceLocation {
360+ Token SemiTok;
361+ return (!getRawTokenAfter (LoopRange.getEnd (), *mSrcMgr ,
362+ Rewr.getLangOpts (), SemiTok) && SemiTok.is (tok::semi)) ?
363+ SemiTok.getLocation () : LoopRange.getEnd ();
364+ };
365+ auto &PragmaLevels = Visitor.getPragmaLevels ();
366+ auto &PragmaLoopsInfo = Visitor.getPragmaLoopsInfo ();
367+ for (auto it = PragmaLevels.rbegin (); it != PragmaLevels.rend (); it++) {
368+ for (auto &Pragma : *it) {
369+ auto PragmaItr = PragmaLoopsInfo.find (Pragma);
370+ assert (PragmaItr != PragmaLoopsInfo.end () &&
371+ " Map should contain all pragmas (as keys) from level vectors." );
372+ const auto &Loops = PragmaItr->second ;
373+ if (Loops.size () < 2 ) {
374+ toDiag (mSrcMgr ->getDiagnostics (),
375+ diag::warn_loop_swapping_missing_loop);
376+ continue ;
377+ }
378+ if (Loops.size () > 2 ) {
379+ toDiag (mSrcMgr ->getDiagnostics (),
380+ diag::warn_loop_swapping_redundant_loop);
381+ }
382+ auto Info0 = Loops[0 ];
383+ auto Info1 = Loops[1 ];
384+ auto Loop0 = Info0.first ;
385+ auto Loop1 = Info1.first ;
386+ auto LoopPair = std::make_pair (Loop0, Loop1);
387+ if (isSwappingAvailable (LoopPair)) {
388+ auto &Range0 = Info0.second ;
389+ auto &Range1 = Info1.second ;
390+ Range0.setEnd (GetLoopEnd (Range0));
391+ Range1.setEnd (GetLoopEnd (Range1));
392+ auto Range0End = Range0.getEnd ();
393+ auto Range1Begin = Range1.getBegin ();
394+ const auto &LoopText0 = Rewr.getRewrittenText (Range0);
395+ const auto &LoopText1 = Rewr.getRewrittenText (Range1);
396+ Rewr.RemoveText (Range0);
397+ Rewr.RemoveText (Range1);
398+ Rewr.InsertTextBefore (Range0End, LoopText1);
399+ Rewr.InsertTextAfter (Range1Begin, LoopText0);
400+ }
378401 }
379402 }
380403}
@@ -413,8 +436,8 @@ bool ClangLoopSwapping::runOnFunction(Function &F) {
413436 LoopVisitor Visitor (mTfmCtx ->getRewriter (), mLoopInfo , *ImportInfo);
414437 mSrcMgr = &mTfmCtx ->getRewriter ().getSourceMgr ();
415438 Visitor.TraverseDecl (FuncDecl);
416- if (Visitor.getLoopCount () == 0 ) {
417- LLVM_DEBUG (dbgs () << " [LOOP SWAPPING]: no loop found.\n " );
439+ if (Visitor.getMaxPragmaDepth () == 0 ) {
440+ LLVM_DEBUG (dbgs () << " [LOOP SWAPPING]: no pragma found.\n " );
418441 return false ;
419442 }
420443 Visitor.printLocations ();
0 commit comments