Skip to content

Commit ac3988b

Browse files
committed
Use single ASTVisitor
1 parent 73d5d64 commit ac3988b

File tree

1 file changed

+59
-102
lines changed

1 file changed

+59
-102
lines changed

lib/Transform/Clang/LoopDistribution.cpp

Lines changed: 59 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
126126
.getMatcher();
127127
mGlobalOptions = &Pass.getAnalysis<GlobalOptionsImmutableWrapper>()
128128
.getOptions();
129+
mRewriter = &TransformationContext.getRewriter();
129130
mSourceManager = &TransformationContext.getRewriter().getSourceMgr();
130131
mASTContext = &TransformationContext.getContext();
131132
auto& SocketInfo = Pass.getAnalysis<AnalysisSocketImmutableWrapper>().get();
@@ -151,15 +152,8 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
151152
}
152153

153154
[[maybe_unused]]
154-
bool TraverseStmt(Stmt *Statement) {
155-
if (!Statement) {
156-
return RecursiveASTVisitor::TraverseStmt(Statement);
157-
}
158-
auto *ForStatement = dyn_cast<ForStmt>(Statement);
159-
if (!ForStatement) {
160-
return RecursiveASTVisitor::TraverseStmt(Statement);
161-
}
162-
155+
bool TraverseForStmt(ForStmt *ForStatement) {
156+
DynTypedNode::create(*ForStatement).dump(dbgs(), *mASTContext);
163157
const auto LoopMatch = mLoopMatcher->find<AST>(ForStatement);
164158
if (LoopMatch == mLoopMatcher->end()) {
165159
return false;
@@ -182,12 +176,13 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
182176
Split->dump();
183177
}
184178
);
185-
processSplits(Splits);
179+
180+
processSplits(ForStatement, Splits);
186181

187182
// TODO: Use this information.
188183
const auto PrevIsInsideLoop = mIsInsideLoop;
189184
mIsInsideLoop = true;
190-
const auto Result = RecursiveASTVisitor::TraverseStmt(ForStatement->getBody());
185+
const auto Result = TraverseStmt(ForStatement->getBody());
191186
mIsInsideLoop = PrevIsInsideLoop;
192187
return Result;
193188
}
@@ -241,11 +236,11 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
241236
if (DIMemoryTraitItr->is<trait::Flow>()) {
242237
printDILocationSource(dwarf::DW_LANG_C, *DIMemory, dbgs());
243238
dbgs() << "Flow dependency\n";
244-
} if (DIMemoryTraitItr->is<trait::Anti>()) {
239+
}
240+
if (DIMemoryTraitItr->is<trait::Anti>()) {
245241
printDILocationSource(dwarf::DW_LANG_C, *DIMemory, dbgs());
246242
dbgs() << "Antiflow dependency\n";
247243
});
248-
249244
const auto *DINode = DIMemory->getAliasNode();
250245
for (const auto &BasicBlock : Loop->blocks()) {
251246
// Get all reads and writes of memory leading to dependencies
@@ -346,7 +341,17 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
346341
return Splits;
347342
}
348343

349-
void processSplits(const SplitInstructionVector &Splits) const {
344+
void processSplits(const ForStmt *ForStatement,
345+
const SplitInstructionVector &Splits) const {
346+
const auto LoopHeaderSplitter = getLoopHeaderSplitter(ForStatement);
347+
if (!LoopHeaderSplitter.hasValue()) {
348+
dbgs() << "Couldn't get character data for ";
349+
ForStatement->dump();
350+
dbgs() << "\n";
351+
return;
352+
}
353+
354+
dbgs() << LoopHeaderSplitter.getValue() << "\n";
350355
for (auto *Split : Splits) {
351356
const auto &ExpressionMatcherItr = mExpressionMatcher->find<IR>(Split);
352357
if (ExpressionMatcherItr == mExpressionMatcher->end()) {
@@ -356,115 +361,66 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
356361
}
357362
const auto &SplitStatement = ExpressionMatcherItr->get<AST>();
358363
SplitStatement.dump(dbgs(), *mASTContext);
364+
// TODO: Incorrect location
365+
mRewriter->InsertText(SplitStatement.getSourceRange().getEnd(),
366+
LoopHeaderSplitter.getValue(), true, true);
359367
}
360368
}
361369

362-
private:
363-
DFRegionInfo *mDFRegion;
364-
TargetLibraryInfo *mTargetLibrary;
365-
AliasTree *mAliasTree;
366-
DominatorTree *mDominatorTree;
367-
DIMemoryClientServerInfo *mServerDIMemory;
368-
SpanningTreeRelation<const DIAliasTree *> *mSpanningTreeRelation;
369-
const CanonicalLoopSet *mCanonicalLoop;
370-
const ClangExprMatcherPass::ExprMatcher *mExpressionMatcher;
371-
const LoopMatcherPass::LoopMatcher *mLoopMatcher;
372-
const GlobalOptions *mGlobalOptions;
373-
const SourceManager *mSourceManager;
374-
const ASTContext *mASTContext;
375-
DIAliasTree *mDIAliasTree;
376-
DIDependencInfo *mDIDependency;
377-
DependenceInfo *mDependence;
378-
std::function<ObjectID(ObjectID)> mGetServerLoopIdFunction;
379-
std::function<Instruction * (Instruction *)> mGetInstructionFunction;
380-
bool mIsInsideLoop = false;
381-
};
382-
383-
class CodeRewriter : public RecursiveASTVisitor<CodeRewriter> {
384-
public:
385-
CodeRewriter(FunctionPass &Pass, Function &Function,
386-
ClangTransformationContext &TransformationContext) {
387-
mRewriter = &TransformationContext.getRewriter();
388-
mSourceMgr = &mRewriter->getSourceMgr();
389-
}
390-
391-
[[maybe_unused]]
392-
bool TraverseStmt(Stmt *Statement) {
393-
if (!Statement) {
394-
return RecursiveASTVisitor::TraverseStmt(Statement);
395-
}
396-
397-
auto *ForStatement = dyn_cast<ForStmt>(Statement);
398-
if (ForStatement) {
399-
return TraverseForStmt(ForStatement);
400-
}
401-
402-
return RecursiveASTVisitor::TraverseStmt(Statement);
403-
}
404-
405-
[[maybe_unused]]
406-
bool TraverseForStmt(ForStmt *ForStatement) {
407-
dbgs() << "First time?\n";
408-
//ForStatement->dump();
409-
410-
mLoopHeaderSplitter = getLoopHeaderSplitter(ForStatement);
411-
dbgs() << mLoopHeaderSplitter << "\n";
412-
413-
// TODO: Use this information.
414-
const auto PrevIsInsideLoop = mIsInsideLoop;
415-
mIsInsideLoop = true;
416-
const auto Result =
417-
RecursiveASTVisitor::TraverseStmt(ForStatement->getBody());
418-
mIsInsideLoop = PrevIsInsideLoop;
419-
//return Result;
420-
return true;
421-
}
422-
423-
[[maybe_unused]]
424-
bool VisitStmt(Stmt *Statement) {
425-
if (!mIsInsideLoop) {
426-
return true;
370+
[[nodiscard]]
371+
Optional<std::string> getLoopHeaderSplitter(
372+
const ForStmt *ForStatement) const {
373+
auto LoopHeader = getCharacterData(
374+
ForStatement->getBeginLoc(), ForStatement->getBody()->getBeginLoc());
375+
if (!LoopHeader.hasValue()) {
376+
return None;
427377
}
428378

429-
/*mRewriter->InsertText(Statement->getEndLoc(),
430-
mLoopHeaderSplitter, true, true);*/
431-
return false;
432-
}
433-
434-
private:
435-
[[nodiscard]]
436-
std::string getLoopHeaderSplitter(ForStmt *ForStatement) const {
437379
std::string LoopHeaderSplitter;
438380
raw_string_ostream SplitterStream(LoopHeaderSplitter);
439381
SplitterStream << "}";
440-
SplitterStream << getCharacterData(ForStatement->getBeginLoc(),
441-
ForStatement->getBody()->getBeginLoc());
382+
SplitterStream << LoopHeader.getValue();
442383
SplitterStream << "{";
443384
return SplitterStream.str();
444385
}
445386

446387
[[nodiscard]]
447-
std::string getCharacterData(const SourceLocation BeginLoc,
448-
const SourceLocation EndLoc) const {
388+
Optional<std::string> getCharacterData(
389+
const SourceLocation BeginLoc, const SourceLocation EndLoc) const {
449390
bool Invalid;
450-
const auto BeginData = mSourceMgr->getCharacterData(BeginLoc, &Invalid);
391+
const auto BeginData = mSourceManager->getCharacterData(BeginLoc, &Invalid);
451392
if (Invalid) {
452-
throw std::exception("Couldn't get character data");
393+
return None;
453394
}
454395

455-
const auto EndData = mSourceMgr->getCharacterData(EndLoc, &Invalid);
396+
const auto EndData = mSourceManager->getCharacterData(EndLoc, &Invalid);
456397
if (Invalid) {
457-
throw std::exception("Couldn't get character data");
398+
return None;
458399
}
459400

460401
return std::string(BeginData, EndData);
461402
}
462-
403+
463404
private:
405+
DFRegionInfo *mDFRegion;
406+
TargetLibraryInfo *mTargetLibrary;
407+
AliasTree *mAliasTree;
408+
DominatorTree *mDominatorTree;
409+
DIMemoryClientServerInfo *mServerDIMemory;
410+
SpanningTreeRelation<const DIAliasTree *> *mSpanningTreeRelation;
411+
const CanonicalLoopSet *mCanonicalLoop;
412+
const ClangExprMatcherPass::ExprMatcher *mExpressionMatcher;
413+
const LoopMatcherPass::LoopMatcher *mLoopMatcher;
414+
const GlobalOptions *mGlobalOptions;
464415
Rewriter *mRewriter;
465-
SourceManager *mSourceMgr;
416+
const SourceManager *mSourceManager;
417+
const ASTContext *mASTContext;
418+
DIAliasTree *mDIAliasTree;
419+
DIDependencInfo *mDIDependency;
420+
DependenceInfo *mDependence;
421+
std::function<ObjectID(ObjectID)> mGetServerLoopIdFunction;
422+
std::function<Instruction * (Instruction *)> mGetInstructionFunction;
466423
bool mIsInsideLoop = false;
467-
std::string mLoopHeaderSplitter;
468424
};
469425
}
470426

@@ -475,19 +431,20 @@ bool LoopDistributionPass::runOnFunction(Function& Function) {
475431
if (!TransformationInfo) {
476432
return false;
477433
}
434+
478435
auto *TransformationContext = TransformationInfo->getContext(*Module);
479436
if (!TransformationContext || !TransformationContext->hasInstance()) {
480437
return false;
481438
}
439+
482440
auto *FunctionDecl =
483441
TransformationContext->getDeclForMangledName(Function.getName());
484442
if (!FunctionDecl) {
485443
return false;
486444
}
487-
//ASTVisitor LoopVisitor(*this, Function, *TransformationContext);
488-
//LoopVisitor.TraverseDecl(FunctionDecl);
489-
CodeRewriter Rewriter(*this, Function, *TransformationContext);
490-
Rewriter.TraverseDecl(FunctionDecl);
445+
446+
ASTVisitor LoopVisitor(*this, Function, *TransformationContext);
447+
LoopVisitor.TraverseDecl(FunctionDecl);
491448
return false;
492449
}
493450

0 commit comments

Comments
 (0)