@@ -126,6 +126,7 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
126126 .getMatcher ();
127127 mGlobalOptions = &Pass.getAnalysis <GlobalOptionsImmutableWrapper>()
128128 .getOptions ();
129+ mRewriter = &TransformationContext.getRewriter ();
129130 mSourceManager = &TransformationContext.getRewriter ().getSourceMgr ();
130131 mASTContext = &TransformationContext.getContext ();
131132 auto & SocketInfo = Pass.getAnalysis <AnalysisSocketImmutableWrapper>().get ();
@@ -151,15 +152,8 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
151152 }
152153
153154 [[maybe_unused]]
154- bool TraverseStmt (Stmt *Statement) {
155- if (!Statement) {
156- return RecursiveASTVisitor::TraverseStmt (Statement);
157- }
158- auto *ForStatement = dyn_cast<ForStmt>(Statement);
159- if (!ForStatement) {
160- return RecursiveASTVisitor::TraverseStmt (Statement);
161- }
162-
155+ bool TraverseForStmt (ForStmt *ForStatement) {
156+ DynTypedNode::create (*ForStatement).dump (dbgs (), *mASTContext );
163157 const auto LoopMatch = mLoopMatcher ->find <AST>(ForStatement);
164158 if (LoopMatch == mLoopMatcher ->end ()) {
165159 return false ;
@@ -182,12 +176,13 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
182176 Split->dump ();
183177 }
184178 );
185- processSplits (Splits);
179+
180+ processSplits (ForStatement, Splits);
186181
187182 // TODO: Use this information.
188183 const auto PrevIsInsideLoop = mIsInsideLoop ;
189184 mIsInsideLoop = true ;
190- const auto Result = RecursiveASTVisitor:: TraverseStmt (ForStatement->getBody ());
185+ const auto Result = TraverseStmt (ForStatement->getBody ());
191186 mIsInsideLoop = PrevIsInsideLoop;
192187 return Result;
193188 }
@@ -241,11 +236,11 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
241236 if (DIMemoryTraitItr->is <trait::Flow>()) {
242237 printDILocationSource (dwarf::DW_LANG_C, *DIMemory, dbgs ());
243238 dbgs () << " Flow dependency\n " ;
244- } if (DIMemoryTraitItr->is <trait::Anti>()) {
239+ }
240+ if (DIMemoryTraitItr->is <trait::Anti>()) {
245241 printDILocationSource (dwarf::DW_LANG_C, *DIMemory, dbgs ());
246242 dbgs () << " Antiflow dependency\n " ;
247243 });
248-
249244 const auto *DINode = DIMemory->getAliasNode ();
250245 for (const auto &BasicBlock : Loop->blocks ()) {
251246 // Get all reads and writes of memory leading to dependencies
@@ -346,7 +341,17 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
346341 return Splits;
347342 }
348343
349- void processSplits (const SplitInstructionVector &Splits) const {
344+ void processSplits (const ForStmt *ForStatement,
345+ const SplitInstructionVector &Splits) const {
346+ const auto LoopHeaderSplitter = getLoopHeaderSplitter (ForStatement);
347+ if (!LoopHeaderSplitter.hasValue ()) {
348+ dbgs () << " Couldn't get character data for " ;
349+ ForStatement->dump ();
350+ dbgs () << " \n " ;
351+ return ;
352+ }
353+
354+ dbgs () << LoopHeaderSplitter.getValue () << " \n " ;
350355 for (auto *Split : Splits) {
351356 const auto &ExpressionMatcherItr = mExpressionMatcher ->find <IR>(Split);
352357 if (ExpressionMatcherItr == mExpressionMatcher ->end ()) {
@@ -356,115 +361,66 @@ class ASTVisitor : public RecursiveASTVisitor<ASTVisitor> {
356361 }
357362 const auto &SplitStatement = ExpressionMatcherItr->get <AST>();
358363 SplitStatement.dump (dbgs (), *mASTContext );
364+ // TODO: Incorrect location
365+ mRewriter ->InsertText (SplitStatement.getSourceRange ().getEnd (),
366+ LoopHeaderSplitter.getValue (), true , true );
359367 }
360368 }
361369
362- private:
363- DFRegionInfo *mDFRegion ;
364- TargetLibraryInfo *mTargetLibrary ;
365- AliasTree *mAliasTree ;
366- DominatorTree *mDominatorTree ;
367- DIMemoryClientServerInfo *mServerDIMemory ;
368- SpanningTreeRelation<const DIAliasTree *> *mSpanningTreeRelation ;
369- const CanonicalLoopSet *mCanonicalLoop ;
370- const ClangExprMatcherPass::ExprMatcher *mExpressionMatcher ;
371- const LoopMatcherPass::LoopMatcher *mLoopMatcher ;
372- const GlobalOptions *mGlobalOptions ;
373- const SourceManager *mSourceManager ;
374- const ASTContext *mASTContext ;
375- DIAliasTree *mDIAliasTree ;
376- DIDependencInfo *mDIDependency ;
377- DependenceInfo *mDependence ;
378- std::function<ObjectID(ObjectID)> mGetServerLoopIdFunction ;
379- std::function<Instruction * (Instruction *)> mGetInstructionFunction ;
380- bool mIsInsideLoop = false ;
381- };
382-
383- class CodeRewriter : public RecursiveASTVisitor <CodeRewriter> {
384- public:
385- CodeRewriter (FunctionPass &Pass, Function &Function,
386- ClangTransformationContext &TransformationContext) {
387- mRewriter = &TransformationContext.getRewriter ();
388- mSourceMgr = &mRewriter ->getSourceMgr ();
389- }
390-
391- [[maybe_unused]]
392- bool TraverseStmt (Stmt *Statement) {
393- if (!Statement) {
394- return RecursiveASTVisitor::TraverseStmt (Statement);
395- }
396-
397- auto *ForStatement = dyn_cast<ForStmt>(Statement);
398- if (ForStatement) {
399- return TraverseForStmt (ForStatement);
400- }
401-
402- return RecursiveASTVisitor::TraverseStmt (Statement);
403- }
404-
405- [[maybe_unused]]
406- bool TraverseForStmt (ForStmt *ForStatement) {
407- dbgs () << " First time?\n " ;
408- // ForStatement->dump();
409-
410- mLoopHeaderSplitter = getLoopHeaderSplitter (ForStatement);
411- dbgs () << mLoopHeaderSplitter << " \n " ;
412-
413- // TODO: Use this information.
414- const auto PrevIsInsideLoop = mIsInsideLoop ;
415- mIsInsideLoop = true ;
416- const auto Result =
417- RecursiveASTVisitor::TraverseStmt (ForStatement->getBody ());
418- mIsInsideLoop = PrevIsInsideLoop;
419- // return Result;
420- return true ;
421- }
422-
423- [[maybe_unused]]
424- bool VisitStmt (Stmt *Statement) {
425- if (!mIsInsideLoop ) {
426- return true ;
370+ [[nodiscard]]
371+ Optional<std::string> getLoopHeaderSplitter (
372+ const ForStmt *ForStatement) const {
373+ auto LoopHeader = getCharacterData (
374+ ForStatement->getBeginLoc (), ForStatement->getBody ()->getBeginLoc ());
375+ if (!LoopHeader.hasValue ()) {
376+ return None;
427377 }
428378
429- /* mRewriter->InsertText(Statement->getEndLoc(),
430- mLoopHeaderSplitter, true, true);*/
431- return false ;
432- }
433-
434- private:
435- [[nodiscard]]
436- std::string getLoopHeaderSplitter (ForStmt *ForStatement) const {
437379 std::string LoopHeaderSplitter;
438380 raw_string_ostream SplitterStream (LoopHeaderSplitter);
439381 SplitterStream << " }" ;
440- SplitterStream << getCharacterData (ForStatement->getBeginLoc (),
441- ForStatement->getBody ()->getBeginLoc ());
382+ SplitterStream << LoopHeader.getValue ();
442383 SplitterStream << " {" ;
443384 return SplitterStream.str ();
444385 }
445386
446387 [[nodiscard]]
447- std::string getCharacterData (const SourceLocation BeginLoc,
448- const SourceLocation EndLoc) const {
388+ Optional< std::string> getCharacterData (
389+ const SourceLocation BeginLoc, const SourceLocation EndLoc) const {
449390 bool Invalid;
450- const auto BeginData = mSourceMgr ->getCharacterData (BeginLoc, &Invalid);
391+ const auto BeginData = mSourceManager ->getCharacterData (BeginLoc, &Invalid);
451392 if (Invalid) {
452- throw std::exception ( " Couldn't get character data " ) ;
393+ return None ;
453394 }
454395
455- const auto EndData = mSourceMgr ->getCharacterData (EndLoc, &Invalid);
396+ const auto EndData = mSourceManager ->getCharacterData (EndLoc, &Invalid);
456397 if (Invalid) {
457- throw std::exception ( " Couldn't get character data " ) ;
398+ return None ;
458399 }
459400
460401 return std::string (BeginData, EndData);
461402 }
462-
403+
463404private:
405+ DFRegionInfo *mDFRegion ;
406+ TargetLibraryInfo *mTargetLibrary ;
407+ AliasTree *mAliasTree ;
408+ DominatorTree *mDominatorTree ;
409+ DIMemoryClientServerInfo *mServerDIMemory ;
410+ SpanningTreeRelation<const DIAliasTree *> *mSpanningTreeRelation ;
411+ const CanonicalLoopSet *mCanonicalLoop ;
412+ const ClangExprMatcherPass::ExprMatcher *mExpressionMatcher ;
413+ const LoopMatcherPass::LoopMatcher *mLoopMatcher ;
414+ const GlobalOptions *mGlobalOptions ;
464415 Rewriter *mRewriter ;
465- SourceManager *mSourceMgr ;
416+ const SourceManager *mSourceManager ;
417+ const ASTContext *mASTContext ;
418+ DIAliasTree *mDIAliasTree ;
419+ DIDependencInfo *mDIDependency ;
420+ DependenceInfo *mDependence ;
421+ std::function<ObjectID(ObjectID)> mGetServerLoopIdFunction ;
422+ std::function<Instruction * (Instruction *)> mGetInstructionFunction ;
466423 bool mIsInsideLoop = false ;
467- std::string mLoopHeaderSplitter ;
468424};
469425}
470426
@@ -475,19 +431,20 @@ bool LoopDistributionPass::runOnFunction(Function& Function) {
475431 if (!TransformationInfo) {
476432 return false ;
477433 }
434+
478435 auto *TransformationContext = TransformationInfo->getContext (*Module);
479436 if (!TransformationContext || !TransformationContext->hasInstance ()) {
480437 return false ;
481438 }
439+
482440 auto *FunctionDecl =
483441 TransformationContext->getDeclForMangledName (Function.getName ());
484442 if (!FunctionDecl) {
485443 return false ;
486444 }
487- // ASTVisitor LoopVisitor(*this, Function, *TransformationContext);
488- // LoopVisitor.TraverseDecl(FunctionDecl);
489- CodeRewriter Rewriter (*this , Function, *TransformationContext);
490- Rewriter.TraverseDecl (FunctionDecl);
445+
446+ ASTVisitor LoopVisitor (*this , Function, *TransformationContext);
447+ LoopVisitor.TraverseDecl (FunctionDecl);
491448 return false ;
492449}
493450
0 commit comments