@@ -54,11 +54,19 @@ using namespace tsar;
5454
5555namespace {
5656
57+ struct LoopRangeInfo {
58+ Loop *LoopPtr;
59+ SourceRange Range;
60+ CompoundStmt *CompStmtPtr;
61+ LoopRangeInfo () : LoopPtr(nullptr ), CompStmtPtr(nullptr ) {}
62+ LoopRangeInfo (llvm::Loop *L, const SourceRange &R, CompoundStmt *S):
63+ LoopPtr (L), Range(R), CompStmtPtr(S) {}
64+ };
65+
5766// / This provides access to function-level analysis results on server.
5867using ClangLoopSwappingProvider =
5968 FunctionPassAAProvider<DIEstimateMemoryPass, DIDependencyAnalysisPass>;
6069using DIAliasTraitVector = std::vector<const DIAliasTrait *>;
61- using LoopRangeInfo = std::pair<Loop *, SourceRange>;
6270using LoopRangeList = SmallVector<LoopRangeInfo, 2 >;
6371using PragmaInfoList = SmallVector<std::pair<Stmt *, LoopRangeList>, 2 >;
6472
@@ -74,7 +82,6 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
7482 , mLangOpts (Rewr.getLangOpts())
7583 , mLoopInfo (LM)
7684 , mState (TraverseState::NONE)
77- , mCurrentLevel (-1 )
7885 {}
7986
8087 bool TraverseStmt (Stmt *S) {
@@ -104,7 +111,6 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
104111 RemoveEmptyLine.RemoveLineIfEmpty = false ;
105112 /* for (auto SR : ToRemove)
106113 mRewriter.RemoveText(SR, RemoveEmptyLine);*/
107- mCurrentLevel ++;
108114 mPragmaLoopsInfo .resize (mPragmaLoopsInfo .size () + 1 );
109115 mPragmaLoopsInfo .back ().first = S;
110116 mState = TraverseState::PRAGMA;
@@ -121,22 +127,30 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
121127 }
122128
123129 bool TraverseCompoundStmt (CompoundStmt *S) {
130+ mCompStmtStack .push (S);
124131 if (mState == TraverseState::PRAGMA) {
125132 mState = TraverseState::OUTERFOR;
126133 auto Res = RecursiveASTVisitor::TraverseCompoundStmt (S);
127134 mState = TraverseState::NONE;
128- mCurrentLevel --;
129135 return Res;
130136 }
131- return RecursiveASTVisitor::TraverseCompoundStmt (S);
137+ auto Res = RecursiveASTVisitor::TraverseCompoundStmt (S);
138+ mCompStmtStack .pop ();
139+ return Res;
132140 }
133141
134142 bool TraverseForStmt (ForStmt *S) {
135143 if (mState == TraverseState::OUTERFOR) {
136144 auto Match = mLoopInfo .find <AST>(S);
137145 if (Match != mLoopInfo .end ()) {
138- mPragmaLoopsInfo .back ().second .push_back (
139- std::make_pair (Match->get <IR>(), S->getSourceRange ()));
146+ auto &LRL = mPragmaLoopsInfo .back ().second ;
147+ if (!LRL.empty () && LRL.back ().CompStmtPtr != mCompStmtStack .top ()) {
148+ toDiag (mSrcMgr .getDiagnostics (), S->getBeginLoc (),
149+ diag::error_loop_swapping_diff_scope);
150+ return false ;
151+ }
152+ LRL.push_back (LoopRangeInfo (Match->get <IR>(), S->getSourceRange (),
153+ mCompStmtStack .top ()));
140154 } else {
141155 toDiag (mSrcMgr .getDiagnostics (), S->getBeginLoc (),
142156 diag::error_loop_swapping_lost_loop);
@@ -164,8 +178,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
164178 ++It, ++N) {
165179 dbgs () << " \t Pragma " << N << " (" << It->first <<" ):\n " ;
166180 for (const auto &Info : It->second ) {
167- const auto LoopPtr = Info.first ;
168- const auto &Range = Info.second ;
181+ const auto LoopPtr = Info.LoopPtr ;
182+ const auto &Range = Info.Range ;
169183 dbgs () << " \t\t [Range]\n " ;
170184 dbgs () << " \t\t Begin:" << Range.getBegin ().printToString (mSrcMgr )
171185 << " \n " ;
@@ -186,8 +200,7 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
186200 const LoopMatcherPass::LoopMatcher &mLoopInfo ;
187201 TraverseState mState ;
188202 SmallVector<Stmt *, 1 > mClauses ;
189-
190- int mCurrentLevel ;
203+ std::stack<CompoundStmt *> mCompStmtStack ;
191204 PragmaInfoList mPragmaLoopsInfo ;
192205};
193206
@@ -314,15 +327,15 @@ bool ClangLoopSwapping::hasTrueOrAntiDependence(
314327
315328bool ClangLoopSwapping::isSwappingAvailable (
316329 const LoopRangeList &LRL, const Stmt *Pragma) const {
317- auto *LoopID0 = mGetLoopID (LRL[0 ].first ->getLoopID ());
318- auto *LoopID1 = mGetLoopID (LRL[1 ].first ->getLoopID ());
330+ auto *LoopID0 = mGetLoopID (LRL[0 ].LoopPtr ->getLoopID ());
331+ auto *LoopID1 = mGetLoopID (LRL[1 ].LoopPtr ->getLoopID ());
319332 if (!LoopID0) {
320- toDiag (mSrcMgr ->getDiagnostics (), LRL[0 ].second .getBegin (),
333+ toDiag (mSrcMgr ->getDiagnostics (), LRL[0 ].Range .getBegin (),
321334 diag::warn_loop_swapping_no_loop_id);
322335 return false ;
323336 }
324337 if (!LoopID1) {
325- toDiag (mSrcMgr ->getDiagnostics (), LRL[1 ].second .getBegin (),
338+ toDiag (mSrcMgr ->getDiagnostics (), LRL[1 ].Range .getBegin (),
326339 diag::warn_loop_swapping_no_loop_id);
327340 return false ;
328341 }
@@ -363,8 +376,8 @@ void ClangLoopSwapping::swapLoops(const LoopVisitor &Visitor) {
363376 diag::warn_loop_swapping_redundant_loop);
364377 }
365378 if (isSwappingAvailable (Loops, Pragma)) {
366- auto Range0 = Loops[0 ].second ;
367- auto Range1 = Loops[1 ].second ;
379+ auto Range0 = Loops[0 ].Range ;
380+ auto Range1 = Loops[1 ].Range ;
368381 Range0.setEnd (GetLoopEnd (Range0));
369382 Range1.setEnd (GetLoopEnd (Range1));
370383 auto Range0End = Range0.getEnd ();
0 commit comments