Skip to content

Commit 2408863

Browse files
committed
[TSAR, Transform] Limit scopes of loops for swapping.
1 parent f6e3de2 commit 2408863

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

include/tsar/Support/DiagnosticKinds.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def warn_loop_swapping_diff_reduction: Warning<"unable to swap loops due to the
184184
def warn_loop_swapping_true_anti_dependence: Warning<"unable to swap loops due to the true or anti dependence">;
185185
def warn_loop_swapping_missing_loop: Warning<"not enough loops for swapping">;
186186
def warn_loop_swapping_redundant_loop: Warning<"too many loops for swapping, ignore redundant">;
187+
def warn_loop_swapping_no_loop_id: Warning<"cannot find loop ID to perform swapping">;
188+
187189
def error_loop_swapping_lost_loop: Error<"cannot match ForStmt with its IR">;
188190
def error_loop_swapping_expect_compound: Error<"expected compound statement after pragma">;
189-
def warn_loop_swapping_no_loop_id: Warning<"cannot find loop ID to perform swapping">;
191+
def error_loop_swapping_diff_scope: Error<"loops within a pragma must have the same scope">;

lib/Transform/Clang/LoopSwapping.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,19 @@ using namespace tsar;
5454

5555
namespace {
5656

57+
struct LoopRangeInfo {
58+
Loop *LoopPtr;
59+
SourceRange Range;
60+
CompoundStmt *CompStmtPtr;
61+
LoopRangeInfo() : LoopPtr(nullptr), CompStmtPtr(nullptr) {}
62+
LoopRangeInfo(llvm::Loop *L, const SourceRange &R, CompoundStmt *S):
63+
LoopPtr(L), Range(R), CompStmtPtr(S) {}
64+
};
65+
5766
/// This provides access to function-level analysis results on server.
5867
using ClangLoopSwappingProvider =
5968
FunctionPassAAProvider<DIEstimateMemoryPass, DIDependencyAnalysisPass>;
6069
using DIAliasTraitVector = std::vector<const DIAliasTrait *>;
61-
using LoopRangeInfo = std::pair<Loop *, SourceRange>;
6270
using LoopRangeList = SmallVector<LoopRangeInfo, 2>;
6371
using PragmaInfoList = SmallVector<std::pair<Stmt *, LoopRangeList>, 2>;
6472

@@ -74,7 +82,6 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
7482
, mLangOpts(Rewr.getLangOpts())
7583
, mLoopInfo(LM)
7684
, mState(TraverseState::NONE)
77-
, mCurrentLevel(-1)
7885
{}
7986

8087
bool TraverseStmt(Stmt *S) {
@@ -104,7 +111,6 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
104111
RemoveEmptyLine.RemoveLineIfEmpty = false;
105112
/*for (auto SR : ToRemove)
106113
mRewriter.RemoveText(SR, RemoveEmptyLine);*/
107-
mCurrentLevel++;
108114
mPragmaLoopsInfo.resize(mPragmaLoopsInfo.size() + 1);
109115
mPragmaLoopsInfo.back().first = S;
110116
mState = TraverseState::PRAGMA;
@@ -121,22 +127,30 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
121127
}
122128

123129
bool TraverseCompoundStmt(CompoundStmt *S) {
130+
mCompStmtStack.push(S);
124131
if (mState == TraverseState::PRAGMA) {
125132
mState = TraverseState::OUTERFOR;
126133
auto Res = RecursiveASTVisitor::TraverseCompoundStmt(S);
127134
mState = TraverseState::NONE;
128-
mCurrentLevel--;
129135
return Res;
130136
}
131-
return RecursiveASTVisitor::TraverseCompoundStmt(S);
137+
auto Res = RecursiveASTVisitor::TraverseCompoundStmt(S);
138+
mCompStmtStack.pop();
139+
return Res;
132140
}
133141

134142
bool TraverseForStmt(ForStmt *S) {
135143
if (mState == TraverseState::OUTERFOR) {
136144
auto Match = mLoopInfo.find<AST>(S);
137145
if (Match != mLoopInfo.end()) {
138-
mPragmaLoopsInfo.back().second.push_back(
139-
std::make_pair(Match->get<IR>(), S->getSourceRange()));
146+
auto &LRL = mPragmaLoopsInfo.back().second;
147+
if (!LRL.empty() && LRL.back().CompStmtPtr != mCompStmtStack.top()) {
148+
toDiag(mSrcMgr.getDiagnostics(), S->getBeginLoc(),
149+
diag::error_loop_swapping_diff_scope);
150+
return false;
151+
}
152+
LRL.push_back(LoopRangeInfo(Match->get<IR>(), S->getSourceRange(),
153+
mCompStmtStack.top()));
140154
} else {
141155
toDiag(mSrcMgr.getDiagnostics(), S->getBeginLoc(),
142156
diag::error_loop_swapping_lost_loop);
@@ -164,8 +178,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
164178
++It, ++N) {
165179
dbgs() << "\tPragma " << N << " (" << It->first <<"):\n";
166180
for (const auto &Info : It->second) {
167-
const auto LoopPtr = Info.first;
168-
const auto &Range = Info.second;
181+
const auto LoopPtr = Info.LoopPtr;
182+
const auto &Range = Info.Range;
169183
dbgs() << "\t\t[Range]\n";
170184
dbgs() << "\t\tBegin:" << Range.getBegin().printToString(mSrcMgr)
171185
<< "\n";
@@ -186,8 +200,7 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
186200
const LoopMatcherPass::LoopMatcher &mLoopInfo;
187201
TraverseState mState;
188202
SmallVector<Stmt *, 1> mClauses;
189-
190-
int mCurrentLevel;
203+
std::stack<CompoundStmt *> mCompStmtStack;
191204
PragmaInfoList mPragmaLoopsInfo;
192205
};
193206

@@ -314,15 +327,15 @@ bool ClangLoopSwapping::hasTrueOrAntiDependence(
314327

315328
bool ClangLoopSwapping::isSwappingAvailable(
316329
const LoopRangeList &LRL, const Stmt *Pragma) const {
317-
auto *LoopID0 = mGetLoopID(LRL[0].first->getLoopID());
318-
auto *LoopID1 = mGetLoopID(LRL[1].first->getLoopID());
330+
auto *LoopID0 = mGetLoopID(LRL[0].LoopPtr->getLoopID());
331+
auto *LoopID1 = mGetLoopID(LRL[1].LoopPtr->getLoopID());
319332
if (!LoopID0) {
320-
toDiag(mSrcMgr->getDiagnostics(), LRL[0].second.getBegin(),
333+
toDiag(mSrcMgr->getDiagnostics(), LRL[0].Range.getBegin(),
321334
diag::warn_loop_swapping_no_loop_id);
322335
return false;
323336
}
324337
if (!LoopID1) {
325-
toDiag(mSrcMgr->getDiagnostics(), LRL[1].second.getBegin(),
338+
toDiag(mSrcMgr->getDiagnostics(), LRL[1].Range.getBegin(),
326339
diag::warn_loop_swapping_no_loop_id);
327340
return false;
328341
}
@@ -363,8 +376,8 @@ void ClangLoopSwapping::swapLoops(const LoopVisitor &Visitor) {
363376
diag::warn_loop_swapping_redundant_loop);
364377
}
365378
if (isSwappingAvailable(Loops, Pragma)) {
366-
auto Range0 = Loops[0].second;
367-
auto Range1 = Loops[1].second;
379+
auto Range0 = Loops[0].Range;
380+
auto Range1 = Loops[1].Range;
368381
Range0.setEnd(GetLoopEnd(Range0));
369382
Range1.setEnd(GetLoopEnd(Range1));
370383
auto Range0End = Range0.getEnd();

0 commit comments

Comments
 (0)