@@ -469,5 +469,128 @@ class OpenACCHostDataConstruct final
469469 return const_cast <OpenACCHostDataConstruct *>(this )->getStructuredBlock ();
470470 }
471471};
472+
473+ // This class represents a 'wait' construct, which has some expressions plus a
474+ // clause list.
475+ class OpenACCWaitConstruct final
476+ : public OpenACCConstructStmt,
477+ private llvm::TrailingObjects<OpenACCWaitConstruct, Expr *,
478+ OpenACCClause *> {
479+ // FIXME: We should be storing a `const OpenACCClause *` to be consistent with
480+ // the rest of the constructs, but TrailingObjects doesn't allow for mixing
481+ // constness in its implementation of `getTrailingObjects`.
482+
483+ friend TrailingObjects;
484+ friend class ASTStmtWriter ;
485+ friend class ASTStmtReader ;
486+ // Locations of the left and right parens of the 'wait-argument'
487+ // expression-list.
488+ SourceLocation LParenLoc, RParenLoc;
489+ // Location of the 'queues' keyword, if present.
490+ SourceLocation QueuesLoc;
491+
492+ // Number of the expressions being represented. Index '0' is always the
493+ // 'devnum' expression, even if it not present.
494+ unsigned NumExprs = 0 ;
495+
496+ OpenACCWaitConstruct (unsigned NumExprs, unsigned NumClauses)
497+ : OpenACCConstructStmt(OpenACCWaitConstructClass,
498+ OpenACCDirectiveKind::Wait, SourceLocation{},
499+ SourceLocation{}, SourceLocation{}),
500+ NumExprs (NumExprs) {
501+ assert (NumExprs >= 1 &&
502+ " NumExprs should always be >= 1 because the 'devnum' "
503+ " expr is represented by a null if necessary" );
504+ std::uninitialized_value_construct (getExprPtr (),
505+ getExprPtr () + NumExprs);
506+ std::uninitialized_value_construct (getTrailingObjects<OpenACCClause *>(),
507+ getTrailingObjects<OpenACCClause *>() +
508+ NumClauses);
509+ setClauseList (MutableArrayRef (const_cast <const OpenACCClause **>(
510+ getTrailingObjects<OpenACCClause *>()),
511+ NumClauses));
512+ }
513+
514+ OpenACCWaitConstruct (SourceLocation Start, SourceLocation DirectiveLoc,
515+ SourceLocation LParenLoc, Expr *DevNumExpr,
516+ SourceLocation QueuesLoc, ArrayRef<Expr *> QueueIdExprs,
517+ SourceLocation RParenLoc, SourceLocation End,
518+ ArrayRef<const OpenACCClause *> Clauses)
519+ : OpenACCConstructStmt(OpenACCWaitConstructClass,
520+ OpenACCDirectiveKind::Wait, Start, DirectiveLoc,
521+ End),
522+ LParenLoc (LParenLoc), RParenLoc(RParenLoc), QueuesLoc(QueuesLoc),
523+ NumExprs(QueueIdExprs.size() + 1) {
524+ assert (NumExprs >= 1 &&
525+ " NumExprs should always be >= 1 because the 'devnum' "
526+ " expr is represented by a null if necessary" );
527+
528+ std::uninitialized_copy (&DevNumExpr, &DevNumExpr + 1 ,
529+ getExprPtr ());
530+ std::uninitialized_copy (QueueIdExprs.begin (), QueueIdExprs.end (),
531+ getExprPtr () + 1 );
532+
533+ std::uninitialized_copy (const_cast <OpenACCClause **>(Clauses.begin ()),
534+ const_cast <OpenACCClause **>(Clauses.end ()),
535+ getTrailingObjects<OpenACCClause *>());
536+ setClauseList (MutableArrayRef (const_cast <const OpenACCClause **>(
537+ getTrailingObjects<OpenACCClause *>()),
538+ Clauses.size ()));
539+ }
540+
541+ size_t numTrailingObjects (OverloadToken<Expr *>) const { return NumExprs; }
542+ size_t numTrailingObjects (OverloadToken<const OpenACCClause *>) const {
543+ return clauses ().size ();
544+ }
545+
546+ Expr **getExprPtr () const {
547+ return const_cast <Expr**>(getTrailingObjects<Expr *>());
548+ }
549+
550+ llvm::ArrayRef<Expr *> getExprs () const {
551+ return llvm::ArrayRef<Expr *>(getExprPtr (), NumExprs);
552+ }
553+
554+ llvm::ArrayRef<Expr *> getExprs () {
555+ return llvm::ArrayRef<Expr *>(getExprPtr (), NumExprs);
556+ }
557+
558+ public:
559+ static bool classof (const Stmt *T) {
560+ return T->getStmtClass () == OpenACCWaitConstructClass;
561+ }
562+
563+ static OpenACCWaitConstruct *
564+ CreateEmpty (const ASTContext &C, unsigned NumExprs, unsigned NumClauses);
565+
566+ static OpenACCWaitConstruct *
567+ Create (const ASTContext &C, SourceLocation Start, SourceLocation DirectiveLoc,
568+ SourceLocation LParenLoc, Expr *DevNumExpr, SourceLocation QueuesLoc,
569+ ArrayRef<Expr *> QueueIdExprs, SourceLocation RParenLoc,
570+ SourceLocation End, ArrayRef<const OpenACCClause *> Clauses);
571+
572+ SourceLocation getLParenLoc () const { return LParenLoc; }
573+ SourceLocation getRParenLoc () const { return RParenLoc; }
574+ bool hasQueuesTag () const { return !QueuesLoc.isInvalid (); }
575+ SourceLocation getQueuesLoc () const { return QueuesLoc; }
576+
577+ bool hasDevNumExpr () const { return getExprs ()[0 ]; }
578+ Expr *getDevNumExpr () const { return getExprs ()[0 ]; }
579+ llvm::ArrayRef<Expr *> getQueueIdExprs () { return getExprs ().drop_front (); }
580+ llvm::ArrayRef<Expr *> getQueueIdExprs () const {
581+ return getExprs ().drop_front ();
582+ }
583+
584+ child_range children () {
585+ Stmt **Begin = reinterpret_cast <Stmt **>(getExprPtr ());
586+ return child_range (Begin, Begin + NumExprs);
587+ }
588+
589+ const_child_range children () const {
590+ Stmt *const *Begin =
591+ reinterpret_cast <Stmt *const *>(getExprPtr ());
592+ return const_child_range (Begin, Begin + NumExprs);
593+ }
594+ };
472595} // namespace clang
473596#endif // LLVM_CLANG_AST_STMTOPENACC_H
0 commit comments