diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index a436676113921..d9f87f1e49b40 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -956,30 +956,46 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { } }; +/// Common class of data shared between +/// OMPCanonicalLoopNestTransformationDirective and transformations over +/// canonical loop sequences. +class OMPLoopTransformationDirective { + /// Number of (top-level) generated loops. + /// This value is 1 for most transformations as they only map one loop nest + /// into another. + /// Some loop transformations (like a non-partial 'unroll') may not generate + /// a loop nest, so this would be 0. + /// Some loop transformations (like 'fuse' with looprange and 'split') may + /// generate more than one loop nest, so the value would be >= 1. + unsigned NumGeneratedTopLevelLoops = 1; + +protected: + void setNumGeneratedTopLevelLoops(unsigned N) { + NumGeneratedTopLevelLoops = N; + } + +public: + unsigned getNumGeneratedTopLevelLoops() const { + return NumGeneratedTopLevelLoops; + } +}; + /// The base class for all transformation directives of canonical loop nests. class OMPCanonicalLoopNestTransformationDirective - : public OMPLoopBasedDirective { + : public OMPLoopBasedDirective, + public OMPLoopTransformationDirective { friend class ASTStmtReader; - /// Number of loops generated by this loop transformation. - unsigned NumGeneratedLoops = 0; - protected: explicit OMPCanonicalLoopNestTransformationDirective( StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} - /// Set the number of loops generated by this loop transformation. - void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } - public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Return the number of loops generated by this loop transformation. - unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } - /// Get the de-sugared statements after the loop transformation. /// /// Might be nullptr if either the directive generates no loops and is handled @@ -5560,9 +5576,7 @@ class OMPTileDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5638,9 +5652,7 @@ class OMPStripeDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5744,7 +5756,8 @@ class OMPUnrollDirective final static OMPUnrollDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef Clauses, Stmt *AssociatedStmt, - unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits); + unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, + Stmt *PreInits); /// Build an empty '#pragma omp unroll' AST node for deserialization. /// @@ -5794,9 +5807,7 @@ class OMPReverseDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5867,9 +5878,7 @@ class OMPInterchangeDirective final SourceLocation EndLoc, unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc, - EndLoc, NumLoops) { - setNumGeneratedLoops(NumLoops); - } + EndLoc, NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index 36ecaf6489ef0..1f6586f95a9f8 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -139,13 +139,14 @@ bool OMPLoopBasedDirective::doForAllLoops( Stmt *TransformedStmt = Dir->getTransformedStmt(); if (!TransformedStmt) { - unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops(); - if (NumGeneratedLoops == 0) { + unsigned NumGeneratedTopLevelLoops = + Dir->getNumGeneratedTopLevelLoops(); + if (NumGeneratedTopLevelLoops == 0) { // May happen if the loop transformation does not result in a // generated loop (such as full unrolling). break; } - if (NumGeneratedLoops > 0) { + if (NumGeneratedTopLevelLoops > 0) { // The loop transformation construct has generated loops, but these // may not have been generated yet due to being in a dependent // context. @@ -447,16 +448,16 @@ OMPStripeDirective *OMPStripeDirective::CreateEmpty(const ASTContext &C, SourceLocation(), SourceLocation(), NumLoops); } -OMPUnrollDirective * -OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation EndLoc, ArrayRef Clauses, - Stmt *AssociatedStmt, unsigned NumGeneratedLoops, - Stmt *TransformedStmt, Stmt *PreInits) { - assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop"); +OMPUnrollDirective *OMPUnrollDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef Clauses, Stmt *AssociatedStmt, + unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, Stmt *PreInits) { + assert(NumGeneratedTopLevelLoops <= 1 && + "Unrolling generates at most one loop"); auto *Dir = createDirective( C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc); - Dir->setNumGeneratedLoops(NumGeneratedLoops); + Dir->setNumGeneratedTopLevelLoops(NumGeneratedTopLevelLoops); Dir->setTransformedStmt(TransformedStmt); Dir->setPreInits(PreInits); return Dir; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 63a56a6583efc..60f0317020c59 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14919,12 +14919,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef Clauses, Body, OriginalInits)) return StmtError(); - unsigned NumGeneratedLoops = PartialClause ? 1 : 0; + unsigned NumGeneratedTopLevelLoops = PartialClause ? 1 : 0; // Delay unrolling to when template is completely instantiated. if (SemaRef.CurContext->isDependentContext()) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, nullptr, nullptr); + NumGeneratedTopLevelLoops, nullptr, + nullptr); assert(LoopHelpers.size() == NumLoops && "Expecting a single-dimensional loop iteration space"); @@ -14947,9 +14948,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef Clauses, // The generated loop may only be passed to other loop-associated directive // when a partial clause is specified. Without the requirement it is // sufficient to generate loop unroll metadata at code-generation. - if (NumGeneratedLoops == 0) + if (NumGeneratedTopLevelLoops == 0) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, nullptr, nullptr); + NumGeneratedTopLevelLoops, nullptr, + nullptr); // Otherwise, we need to provide a de-sugared/transformed AST that can be // associated with another loop directive. @@ -15164,7 +15166,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef Clauses, LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, OuterFor, + NumGeneratedTopLevelLoops, OuterFor, buildPreInits(Context, PreInits)); } diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 7ec8e450fbaca..213c2c2148f64 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2450,7 +2450,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective( OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); - D->setNumGeneratedLoops(Record.readUInt32()); + D->setNumGeneratedTopLevelLoops(Record.readUInt32()); } void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index 07a5cde47a9a8..21c04ddbc2c7a 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2459,7 +2459,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtWriter::VisitOMPCanonicalLoopNestTransformationDirective( OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); - Record.writeUInt32(D->getNumGeneratedLoops()); + Record.writeUInt32(D->getNumGeneratedTopLevelLoops()); } void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {