Skip to content

Commit c03f3eb

Browse files
committed
[TSAR, Transform] Change traversal of pragma loops.
1 parent dc4fead commit c03f3eb

File tree

2 files changed

+44
-43
lines changed

2 files changed

+44
-43
lines changed

include/tsar/Support/DiagnosticKinds.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def warn_loop_swapping_true_anti_dependence: Warning<"unable to swap loops due t
185185
def warn_loop_swapping_missing_loop: Warning<"not enough loops for swapping">;
186186
def warn_loop_swapping_redundant_loop: Warning<"too many loops for swapping, ignore redundant">;
187187
def warn_loop_swapping_no_loop_id: Warning<"cannot find loop ID to perform swapping">;
188-
189-
def error_loop_swapping_lost_loop: Error<"cannot match ForStmt with its IR">;
188+
def warn_loop_swapping_lost_loop: Warning<"cannot match ForStmt with its IR">;
189+
def warn_loop_swapping_redundant_stmt: Warning<"pragma should only contain loops or other pragma">;
190190
def error_loop_swapping_expect_compound: Error<"expected compound statement after pragma">;
191-
def error_loop_swapping_redundant_stmt: Error<"pragma should only contain loops or other pragma">;
191+

lib/Transform/Clang/LoopSwapping.cpp

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
#include "tsar/Analysis/Memory/DIDependencyAnalysis.h"
3030
#include "tsar/Analysis/Memory/DIEstimateMemory.h"
3131
#include "tsar/Analysis/Memory/MemoryTraitUtils.h"
32-
#include "tsar/Core/TransformationContext.h"
3332
#include "tsar/Core/Query.h"
3433
#include "tsar/Frontend/Clang/Pragma.h"
34+
#include "tsar/Frontend/Clang/TransformationContext.h"
3535
#include "tsar/Support/Clang/Diagnostic.h"
3636
#include "tsar/Support/Clang/Utils.h"
3737
#include "tsar/Support/Clang/SourceLocationTraverse.h"
@@ -42,7 +42,6 @@
4242
#include <clang/AST/RecursiveASTVisitor.h>
4343
#include <llvm/ADT/DenseMap.h>
4444
#include <llvm/Analysis/LoopInfo.h>
45-
#include <vector>
4645
#include <stack>
4746

4847
using namespace llvm;
@@ -57,10 +56,10 @@ namespace {
5756
/// This provides access to function-level analysis results on server.
5857
using ClangLoopSwappingProvider =
5958
FunctionPassAAProvider<DIEstimateMemoryPass, DIDependencyAnalysisPass>;
60-
using DIAliasTraitVector = std::vector<const DIAliasTrait *>;
59+
using DIAliasTraitList = SmallVector<const DIAliasTrait *, 8>;
6160
using LoopRangeInfo = std::pair<Loop *, SourceRange>;
6261
using LoopRangeList = SmallVector<LoopRangeInfo, 2>;
63-
using PragmaInfoList = SmallVector<std::pair<Stmt *, LoopRangeList>, 2>;
62+
using PragmaInfoList = DenseMap<Stmt *, LoopRangeList>;
6463

6564
class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
6665
private:
@@ -103,8 +102,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
103102
RemoveEmptyLine.RemoveLineIfEmpty = false;
104103
/*for (auto SR : ToRemove)
105104
mRewriter.RemoveText(SR, RemoveEmptyLine);*/
106-
mPragmaLoopsInfo.resize(mPragmaLoopsInfo.size() + 1);
107-
mPragmaLoopsInfo.back().first = S;
105+
mPragmaLoopsInfo.insert(std::make_pair(S, LoopRangeList()));
106+
mPragmaStack.push(S);
108107
mState = TraverseState::PRAGMA;
109108
}
110109
return true;
@@ -117,7 +116,7 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
117116
}
118117
if (mState == TraverseState::OUTERFOR && !dyn_cast<ForStmt>(S)) {
119118
toDiag(mSrcMgr.getDiagnostics(), S->getBeginLoc(),
120-
diag::error_loop_swapping_redundant_stmt);
119+
diag::warn_loop_swapping_redundant_stmt);
121120
return false;
122121
}
123122
return RecursiveASTVisitor::TraverseStmt(S);
@@ -127,22 +126,22 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
127126
if (mState == TraverseState::PRAGMA) {
128127
mState = TraverseState::OUTERFOR;
129128
auto Res = RecursiveASTVisitor::TraverseCompoundStmt(S);
130-
mState = TraverseState::NONE;
129+
mPragmaStack.pop();
130+
mState = mPragmaStack.empty() ? TraverseState::NONE : TraverseState::OUTERFOR;
131131
return Res;
132132
}
133-
auto Res = RecursiveASTVisitor::TraverseCompoundStmt(S);
134-
return Res;
133+
return RecursiveASTVisitor::TraverseCompoundStmt(S);
135134
}
136135

137136
bool TraverseForStmt(ForStmt *S) {
138137
if (mState == TraverseState::OUTERFOR) {
139138
auto Match = mLoopInfo.find<AST>(S);
140139
if (Match != mLoopInfo.end()) {
141-
auto &LRL = mPragmaLoopsInfo.back().second;
140+
auto &LRL = mPragmaLoopsInfo[mPragmaStack.top()];
142141
LRL.push_back(std::make_pair(Match->get<IR>(), S->getSourceRange()));
143142
} else {
144143
toDiag(mSrcMgr.getDiagnostics(), S->getBeginLoc(),
145-
diag::error_loop_swapping_lost_loop);
144+
diag::warn_loop_swapping_lost_loop);
146145
}
147146
mState = TraverseState::INNERFOR;
148147
auto Res = RecursiveASTVisitor::TraverseForStmt(S);
@@ -165,17 +164,17 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
165164
int N = 0;
166165
for (auto It = mPragmaLoopsInfo.begin(); It != mPragmaLoopsInfo.end();
167166
++It, ++N) {
168-
dbgs() << "\tPragma " << N << " (" << It->first <<"):\n";
167+
dbgs() << "Pragma " << N << " (" << It->first <<"):\n";
169168
for (const auto &Info : It->second) {
170169
const auto LoopPtr = Info.first;
171170
const auto &Range = Info.second;
172-
dbgs() << "\t\t[Range]\n";
173-
dbgs() << "\t\tBegin:" << Range.getBegin().printToString(mSrcMgr)
171+
dbgs() << "\t[Range]\n";
172+
dbgs() << "\tBegin:" << Range.getBegin().printToString(mSrcMgr)
174173
<< "\n";
175-
dbgs() << "\t\tEnd:" << Range.getEnd().printToString(mSrcMgr) <<"\n";
176-
dbgs() << "\t\t\n\t\t[Loop]\n";
174+
dbgs() << "\tEnd:" << Range.getEnd().printToString(mSrcMgr) <<"\n";
175+
dbgs() << "\t\n\t\t[Loop]\n";
177176
const auto &LoopText = mRewriter.getRewrittenText(Range);
178-
dbgs() << "\t\t" << LoopText << "\n\n";
177+
dbgs() << "\t" << LoopText << "\n\n";
179178
}
180179
}
181180
}
@@ -189,7 +188,8 @@ class LoopVisitor : public RecursiveASTVisitor<LoopVisitor> {
189188
const LoopMatcherPass::LoopMatcher &mLoopInfo;
190189
TraverseState mState;
191190
SmallVector<Stmt *, 1> mClauses;
192-
PragmaInfoList mPragmaLoopsInfo;
191+
PragmaInfoList mPragmaLoopsInfo;
192+
std::stack<Stmt *> mPragmaStack;
193193
};
194194

195195
class ClangLoopSwapping : public FunctionPass, private bcl::Uncopyable {
@@ -204,12 +204,12 @@ class ClangLoopSwapping : public FunctionPass, private bcl::Uncopyable {
204204

205205
private:
206206
void swapLoops(const LoopVisitor &Visitor);
207-
DIAliasTraitVector getLoopTraits(MDNode *LoopID) const;
207+
DIAliasTraitList getLoopTraits(MDNode *LoopID) const;
208208
bool isSwappingAvailable(const LoopRangeList &LRL, const Stmt *Pragma) const;
209-
bool hasSameReductionKind(const DIAliasTraitVector &TV0,
210-
const DIAliasTraitVector &TV1) const;
211-
bool hasTrueOrAntiDependence(const DIAliasTraitVector &TV0,
212-
const DIAliasTraitVector &TV1) const;
209+
bool hasSameReductionKind(const DIAliasTraitList &TV0,
210+
const DIAliasTraitList &TV1) const;
211+
bool hasTrueOrAntiDependence(const DIAliasTraitList &TV0,
212+
const DIAliasTraitList &TV1) const;
213213

214214
Function *mFunction = nullptr;
215215
TransformationContext *mTfmCtx = nullptr;
@@ -244,14 +244,14 @@ class ClangLoopSwappingInfo final : public PassGroupInfo {
244244

245245
char ClangLoopSwapping::ID = 0;
246246

247-
DIAliasTraitVector ClangLoopSwapping::getLoopTraits(MDNode *LoopID) const {
247+
DIAliasTraitList ClangLoopSwapping::getLoopTraits(MDNode *LoopID) const {
248248
auto DepItr = mDIDepInfo->find(LoopID);
249249
assert(DepItr != mDIDepInfo->end() && "Loop must be analyzed!");
250250
auto &DIDepSet = DepItr->get<DIDependenceSet>();
251251
DenseSet<const DIAliasNode *> Coverage;
252252
accessCoverage<bcl::SimpleInserter>(DIDepSet, *mDIAT, Coverage,
253253
mGlobalOpts->IgnoreRedundantMemory);
254-
DIAliasTraitVector Traits;
254+
DIAliasTraitList Traits;
255255
for (auto &TS : DIDepSet) {
256256
if (!Coverage.count(TS.getNode()))
257257
continue;
@@ -261,20 +261,17 @@ DIAliasTraitVector ClangLoopSwapping::getLoopTraits(MDNode *LoopID) const {
261261
}
262262

263263
bool ClangLoopSwapping::hasSameReductionKind(
264-
const DIAliasTraitVector &TV0, const DIAliasTraitVector &TV1) const {
264+
const DIAliasTraitList &TV0, const DIAliasTraitList &TV1) const {
265265
for (auto &TS0: TV0) {
266-
auto *Node0 = TS0->getNode();
267-
MemoryDescriptor Dptr0 = *TS0;
268-
if (!Dptr0.is<trait::Reduction>())
266+
if (!TS0->is<trait::Reduction>())
269267
continue;
268+
auto *Node0 = TS0->getNode();
270269
for (auto &TS1: TV1) {
271270
auto *Node1 = TS1->getNode();
272-
MemoryDescriptor Dptr1 = *TS1;
273-
if (Node0 == Node1 && Dptr1.is<trait::Reduction>()) {
271+
if (Node0 == Node1 && TS1->is<trait::Reduction>()) {
274272
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: Same nodes with reduction.\n");
275-
auto I0 = TS0->begin(), I1 = TS1->begin();
276-
auto *Red0 = (**I0).get<trait::Reduction>();
277-
auto *Red1 = (**I1).get<trait::Reduction>();
273+
auto *Red0 = (**TS0->begin()).get<trait::Reduction>();
274+
auto *Red1 = (**TS1->begin()).get<trait::Reduction>();
278275
if (!Red0 || !Red1) {
279276
LLVM_DEBUG(dbgs() << "[LOOP SWAPPING]: Unknown Reduction.\n");
280277
return false;
@@ -294,7 +291,7 @@ bool ClangLoopSwapping::hasSameReductionKind(
294291
}
295292

296293
bool ClangLoopSwapping::hasTrueOrAntiDependence(
297-
const DIAliasTraitVector &TV0, const DIAliasTraitVector &TV1) const {
294+
const DIAliasTraitList &TV0, const DIAliasTraitList &TV1) const {
298295
SpanningTreeRelation<DIAliasTree *> STR(mDIAT);
299296
for (auto &TS0: TV0) {
300297
for (auto &TS1: TV1) {
@@ -315,8 +312,11 @@ bool ClangLoopSwapping::hasTrueOrAntiDependence(
315312

316313
bool ClangLoopSwapping::isSwappingAvailable(
317314
const LoopRangeList &LRL, const Stmt *Pragma) const {
318-
auto *LoopID0 = mGetLoopID(LRL[0].first->getLoopID());
319-
auto *LoopID1 = mGetLoopID(LRL[1].first->getLoopID());
315+
auto ClientLoopID0 = LRL[0].first->getLoopID();
316+
auto ClientLoopID1 = LRL[1].first->getLoopID();
317+
assert(ClientLoopID0 && ClientLoopID1 && "LoopID must not be null!");
318+
auto *LoopID0 = mGetLoopID(ClientLoopID0);
319+
auto *LoopID1 = mGetLoopID(ClientLoopID1);
320320
if (!LoopID0) {
321321
toDiag(mSrcMgr->getDiagnostics(), LRL[0].second.getBegin(),
322322
diag::warn_loop_swapping_no_loop_id);
@@ -383,10 +383,11 @@ void ClangLoopSwapping::swapLoops(const LoopVisitor &Visitor) {
383383
bool ClangLoopSwapping::runOnFunction(Function &F) {
384384
mFunction = &F;
385385
auto *M = F.getParent();
386-
mTfmCtx = getAnalysis<TransformationEnginePass>().getContext(*M);
386+
auto &TfmInfo = getAnalysis<TransformationEnginePass>();
387+
mTfmCtx = TfmInfo ? TfmInfo->getContext(*M) : nullptr;
387388
if (!mTfmCtx || !mTfmCtx->hasInstance()) {
388389
M->getContext().emitError("can not transform sources"
389-
": transformation context is not available");
390+
": transformation context is not available");
390391
return false;
391392
}
392393
auto FuncDecl = mTfmCtx->getDeclForMangledName(F.getName());

0 commit comments

Comments
 (0)