Skip to content

Commit c2b683d

Browse files
ro-ironlieb
authored andcommitted
Reland: [OpenMP][clang] 6.0: num_threads strict (part 3: codegen) (llvm#155839)
OpenMP 6.0 12.1.2 specifies the behavior of the strict modifier for the num_threads clause on parallel directives, along with the message and severity clauses. This commit implements necessary codegen changes.
1 parent bb27715 commit c2b683d

29 files changed

+17061
-1051
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,62 +1865,43 @@ class OMPSeverityClause final : public OMPClause {
18651865
/// \endcode
18661866
/// In this example directive '#pragma omp error' has simple
18671867
/// 'message' clause with user error message of "GNU compiler required.".
1868-
class OMPMessageClause final : public OMPClause {
1868+
class OMPMessageClause final
1869+
: public OMPOneStmtClause<llvm::omp::OMPC_message, OMPClause>,
1870+
public OMPClauseWithPreInit {
18691871
friend class OMPClauseReader;
18701872

1871-
/// Location of '('
1872-
SourceLocation LParenLoc;
1873-
1874-
// Expression of the 'message' clause.
1875-
Stmt *MessageString = nullptr;
1876-
18771873
/// Set message string of the clause.
1878-
void setMessageString(Expr *MS) { MessageString = MS; }
1879-
1880-
/// Sets the location of '('.
1881-
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
1874+
void setMessageString(Expr *MS) { setStmt(MS); }
18821875

18831876
public:
18841877
/// Build 'message' clause with message string argument
18851878
///
18861879
/// \param MS Argument of the clause (message string).
1880+
/// \param HelperMS Helper statement for the construct.
1881+
/// \param CaptureRegion Innermost OpenMP region where expressions in this
1882+
/// clause must be captured.
18871883
/// \param StartLoc Starting location of the clause.
18881884
/// \param LParenLoc Location of '('.
18891885
/// \param EndLoc Ending location of the clause.
1890-
OMPMessageClause(Expr *MS, SourceLocation StartLoc, SourceLocation LParenLoc,
1886+
OMPMessageClause(Expr *MS, Stmt *HelperMS, OpenMPDirectiveKind CaptureRegion,
1887+
SourceLocation StartLoc, SourceLocation LParenLoc,
18911888
SourceLocation EndLoc)
1892-
: OMPClause(llvm::omp::OMPC_message, StartLoc, EndLoc),
1893-
LParenLoc(LParenLoc), MessageString(MS) {}
1894-
1895-
/// Build an empty clause.
1896-
OMPMessageClause()
1897-
: OMPClause(llvm::omp::OMPC_message, SourceLocation(), SourceLocation()) {
1889+
: OMPOneStmtClause(MS, StartLoc, LParenLoc, EndLoc),
1890+
OMPClauseWithPreInit(this) {
1891+
setPreInitStmt(HelperMS, CaptureRegion);
18981892
}
18991893

1900-
/// Returns the locaiton of '('.
1901-
SourceLocation getLParenLoc() const { return LParenLoc; }
1894+
/// Build an empty clause.
1895+
OMPMessageClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
19021896

19031897
/// Returns message string of the clause.
1904-
Expr *getMessageString() const { return cast_or_null<Expr>(MessageString); }
1905-
1906-
child_range children() {
1907-
return child_range(&MessageString, &MessageString + 1);
1908-
}
1909-
1910-
const_child_range children() const {
1911-
return const_child_range(&MessageString, &MessageString + 1);
1912-
}
1913-
1914-
child_range used_children() {
1915-
return child_range(child_iterator(), child_iterator());
1916-
}
1917-
1918-
const_child_range used_children() const {
1919-
return const_child_range(const_child_iterator(), const_child_iterator());
1920-
}
1898+
Expr *getMessageString() const { return getStmtAs<Expr>(); }
19211899

1922-
static bool classof(const OMPClause *T) {
1923-
return T->getClauseKind() == llvm::omp::OMPC_message;
1900+
/// Try to evaluate the message string at compile time.
1901+
std::optional<std::string> tryEvaluateString(ASTContext &Ctx) const {
1902+
if (Expr *MessageExpr = getMessageString())
1903+
return MessageExpr->tryEvaluateString(Ctx);
1904+
return std::nullopt;
19241905
}
19251906
};
19261907

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,8 +1506,8 @@ def err_omp_unexpected_directive : Error<
15061506
"unexpected OpenMP directive %select{|'#pragma omp %1'}0">;
15071507
def err_omp_expected_punc : Error<
15081508
"expected ',' or ')' in '%0' %select{clause|directive}1">;
1509-
def warn_clause_expected_string : Warning<
1510-
"expected string literal in 'clause %0' - ignoring">, InGroup<IgnoredPragmas>;
1509+
def warn_clause_expected_string: Warning<
1510+
"expected string %select{|literal }1in 'clause %0' - ignoring">, InGroup<IgnoredPragmas>;
15111511
def err_omp_unexpected_clause : Error<
15121512
"unexpected OpenMP clause '%0' in directive '#pragma omp %1'">;
15131513
def err_omp_unexpected_clause_extension_only : Error<

clang/lib/AST/OpenMPClause.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
104104
return static_cast<const OMPFilterClause *>(C);
105105
case OMPC_ompx_dyn_cgroup_mem:
106106
return static_cast<const OMPXDynCGroupMemClause *>(C);
107+
case OMPC_message:
108+
return static_cast<const OMPMessageClause *>(C);
107109
case OMPC_default:
108110
case OMPC_proc_bind:
109111
case OMPC_safelen:
@@ -158,7 +160,6 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
158160
case OMPC_self_maps:
159161
case OMPC_at:
160162
case OMPC_severity:
161-
case OMPC_message:
162163
case OMPC_device_type:
163164
case OMPC_match:
164165
case OMPC_nontemporal:
@@ -1963,8 +1964,10 @@ void OMPClausePrinter::VisitOMPSeverityClause(OMPSeverityClause *Node) {
19631964
}
19641965

19651966
void OMPClausePrinter::VisitOMPMessageClause(OMPMessageClause *Node) {
1966-
OS << "message(\""
1967-
<< cast<StringLiteral>(Node->getMessageString())->getString() << "\")";
1967+
OS << "message(";
1968+
if (Expr *E = Node->getMessageString())
1969+
E->printPretty(OS, nullptr, Policy);
1970+
OS << ")";
19681971
}
19691972

19701973
void OMPClausePrinter::VisitOMPScheduleClause(OMPScheduleClause *Node) {

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,11 +1869,11 @@ void CGOpenMPRuntime::emitIfClause(CodeGenFunction &CGF, const Expr *Cond,
18691869
CGF.EmitBlock(ContBlock, /*IsFinished=*/true);
18701870
}
18711871

1872-
void CGOpenMPRuntime::emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
1873-
llvm::Function *OutlinedFn,
1874-
ArrayRef<llvm::Value *> CapturedVars,
1875-
const Expr *IfCond,
1876-
llvm::Value *NumThreads) {
1872+
void CGOpenMPRuntime::emitParallelCall(
1873+
CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
1874+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
1875+
llvm::Value *NumThreads, OpenMPNumThreadsClauseModifier NumThreadsModifier,
1876+
OpenMPSeverityClauseKind Severity, const Expr *Message) {
18771877
if (!CGF.HaveInsertPoint())
18781878
return;
18791879
llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
@@ -2396,9 +2396,8 @@ void CGOpenMPRuntime::emitBarrierCall(CodeGenFunction &CGF, SourceLocation Loc,
23962396

23972397
void CGOpenMPRuntime::emitErrorCall(CodeGenFunction &CGF, SourceLocation Loc,
23982398
Expr *ME, bool IsFatal) {
2399-
llvm::Value *MVL =
2400-
ME ? CGF.EmitStringLiteralLValue(cast<StringLiteral>(ME)).getPointer(CGF)
2401-
: llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
2399+
llvm::Value *MVL = ME ? CGF.EmitScalarExpr(ME)
2400+
: llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
24022401
// Build call void __kmpc_error(ident_t *loc, int severity, const char
24032402
// *message)
24042403
llvm::Value *Args[] = {
@@ -2746,18 +2745,54 @@ llvm::Value *CGOpenMPRuntime::emitForNext(CodeGenFunction &CGF,
27462745
CGF.getContext().BoolTy, Loc);
27472746
}
27482747

2749-
void CGOpenMPRuntime::emitNumThreadsClause(CodeGenFunction &CGF,
2750-
llvm::Value *NumThreads,
2751-
SourceLocation Loc) {
2748+
llvm::Value *CGOpenMPRuntime::emitMessageClause(CodeGenFunction &CGF,
2749+
const Expr *Message) {
2750+
if (!Message)
2751+
return llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
2752+
return CGF.EmitScalarExpr(Message);
2753+
}
2754+
2755+
llvm::Value *
2756+
CGOpenMPRuntime::emitMessageClause(CodeGenFunction &CGF,
2757+
const OMPMessageClause *MessageClause) {
2758+
return emitMessageClause(
2759+
CGF, MessageClause ? MessageClause->getMessageString() : nullptr);
2760+
}
2761+
2762+
llvm::Value *
2763+
CGOpenMPRuntime::emitSeverityClause(OpenMPSeverityClauseKind Severity) {
2764+
// OpenMP 6.0, 10.4: "If no severity clause is specified then the effect is
2765+
// as if sev-level is fatal."
2766+
return llvm::ConstantInt::get(CGM.Int32Ty,
2767+
Severity == OMPC_SEVERITY_warning ? 1 : 2);
2768+
}
2769+
2770+
llvm::Value *
2771+
CGOpenMPRuntime::emitSeverityClause(const OMPSeverityClause *SeverityClause) {
2772+
return emitSeverityClause(SeverityClause ? SeverityClause->getSeverityKind()
2773+
: OMPC_SEVERITY_unknown);
2774+
}
2775+
2776+
void CGOpenMPRuntime::emitNumThreadsClause(
2777+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
2778+
OpenMPNumThreadsClauseModifier Modifier, OpenMPSeverityClauseKind Severity,
2779+
const Expr *Message) {
27522780
if (!CGF.HaveInsertPoint())
27532781
return;
2782+
llvm::SmallVector<llvm::Value *, 4> Args(
2783+
{emitUpdateLocation(CGF, Loc), getThreadID(CGF, Loc),
2784+
CGF.Builder.CreateIntCast(NumThreads, CGF.Int32Ty, /*isSigned*/ true)});
27542785
// Build call __kmpc_push_num_threads(&loc, global_tid, num_threads)
2755-
llvm::Value *Args[] = {
2756-
emitUpdateLocation(CGF, Loc), getThreadID(CGF, Loc),
2757-
CGF.Builder.CreateIntCast(NumThreads, CGF.Int32Ty, /*isSigned*/ true)};
2758-
CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
2759-
CGM.getModule(), OMPRTL___kmpc_push_num_threads),
2760-
Args);
2786+
// or __kmpc_push_num_threads_strict(&loc, global_tid, num_threads, severity,
2787+
// messsage) if strict modifier is used.
2788+
RuntimeFunction FnID = OMPRTL___kmpc_push_num_threads;
2789+
if (Modifier == OMPC_NUMTHREADS_strict) {
2790+
FnID = OMPRTL___kmpc_push_num_threads_strict;
2791+
Args.push_back(emitSeverityClause(Severity));
2792+
Args.push_back(emitMessageClause(CGF, Message));
2793+
}
2794+
CGF.EmitRuntimeCall(
2795+
OMPBuilder.getOrCreateRuntimeFunction(CGM.getModule(), FnID), Args);
27612796
}
27622797

27632798
void CGOpenMPRuntime::emitProcBindClause(CodeGenFunction &CGF,
@@ -12552,12 +12587,11 @@ llvm::Function *CGOpenMPSIMDRuntime::emitTaskOutlinedFunction(
1255212587
llvm_unreachable("Not supported in SIMD-only mode");
1255312588
}
1255412589

12555-
void CGOpenMPSIMDRuntime::emitParallelCall(CodeGenFunction &CGF,
12556-
SourceLocation Loc,
12557-
llvm::Function *OutlinedFn,
12558-
ArrayRef<llvm::Value *> CapturedVars,
12559-
const Expr *IfCond,
12560-
llvm::Value *NumThreads) {
12590+
void CGOpenMPSIMDRuntime::emitParallelCall(
12591+
CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
12592+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
12593+
llvm::Value *NumThreads, OpenMPNumThreadsClauseModifier NumThreadsModifier,
12594+
OpenMPSeverityClauseKind Severity, const Expr *Message) {
1256112595
llvm_unreachable("Not supported in SIMD-only mode");
1256212596
}
1256312597

@@ -12661,9 +12695,10 @@ llvm::Value *CGOpenMPSIMDRuntime::emitForNext(CodeGenFunction &CGF,
1266112695
llvm_unreachable("Not supported in SIMD-only mode");
1266212696
}
1266312697

12664-
void CGOpenMPSIMDRuntime::emitNumThreadsClause(CodeGenFunction &CGF,
12665-
llvm::Value *NumThreads,
12666-
SourceLocation Loc) {
12698+
void CGOpenMPSIMDRuntime::emitNumThreadsClause(
12699+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
12700+
OpenMPNumThreadsClauseModifier Modifier, OpenMPSeverityClauseKind Severity,
12701+
const Expr *Message) {
1266712702
llvm_unreachable("Not supported in SIMD-only mode");
1266812703
}
1266912704

clang/lib/CodeGen/CGOpenMPRuntime.h

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -778,11 +778,22 @@ class CGOpenMPRuntime {
778778
/// specified, nullptr otherwise.
779779
/// \param NumThreads The value corresponding to the num_threads clause, if
780780
/// any, or nullptr.
781+
/// \param NumThreadsModifier The modifier of the num_threads clause, if
782+
/// any, ignored otherwise.
783+
/// \param Severity The severity corresponding to the num_threads clause, if
784+
/// any, ignored otherwise.
785+
/// \param Message The message string corresponding to the num_threads clause,
786+
/// if any, or nullptr.
781787
///
782-
virtual void emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
783-
llvm::Function *OutlinedFn,
784-
ArrayRef<llvm::Value *> CapturedVars,
785-
const Expr *IfCond, llvm::Value *NumThreads);
788+
virtual void
789+
emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
790+
llvm::Function *OutlinedFn,
791+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
792+
llvm::Value *NumThreads,
793+
OpenMPNumThreadsClauseModifier NumThreadsModifier =
794+
OMPC_NUMTHREADS_unknown,
795+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
796+
const Expr *Message = nullptr);
786797

787798
/// Emits a critical region.
788799
/// \param CriticalName Name of the critical region.
@@ -1050,13 +1061,28 @@ class CGOpenMPRuntime {
10501061
Address IL, Address LB,
10511062
Address UB, Address ST);
10521063

1064+
virtual llvm::Value *emitMessageClause(CodeGenFunction &CGF,
1065+
const Expr *Message);
1066+
virtual llvm::Value *emitMessageClause(CodeGenFunction &CGF,
1067+
const OMPMessageClause *MessageClause);
1068+
1069+
virtual llvm::Value *emitSeverityClause(OpenMPSeverityClauseKind Severity);
1070+
virtual llvm::Value *
1071+
emitSeverityClause(const OMPSeverityClause *SeverityClause);
1072+
10531073
/// Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32
10541074
/// global_tid, kmp_int32 num_threads) to generate code for 'num_threads'
10551075
/// clause.
1076+
/// If the modifier 'strict' is given:
1077+
/// Emits call to void __kmpc_push_num_threads_strict(ident_t *loc, kmp_int32
1078+
/// global_tid, kmp_int32 num_threads, int severity, const char *message) to
1079+
/// generate code for 'num_threads' clause with 'strict' modifier.
10561080
/// \param NumThreads An integer value of threads.
1057-
virtual void emitNumThreadsClause(CodeGenFunction &CGF,
1058-
llvm::Value *NumThreads,
1059-
SourceLocation Loc);
1081+
virtual void emitNumThreadsClause(
1082+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
1083+
OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown,
1084+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1085+
const Expr *Message = nullptr);
10601086

10611087
/// Emit call to void __kmpc_push_proc_bind(ident_t *loc, kmp_int32
10621088
/// global_tid, int proc_bind) to generate code for 'proc_bind' clause.
@@ -1774,11 +1800,21 @@ class CGOpenMPSIMDRuntime final : public CGOpenMPRuntime {
17741800
/// specified, nullptr otherwise.
17751801
/// \param NumThreads The value corresponding to the num_threads clause, if
17761802
/// any, or nullptr.
1803+
/// \param NumThreadsModifier The modifier of the num_threads clause, if
1804+
/// any, ignored otherwise.
1805+
/// \param Severity The severity corresponding to the num_threads clause, if
1806+
/// any, ignored otherwise.
1807+
/// \param Message The message string corresponding to the num_threads clause,
1808+
/// if any, or nullptr.
17771809
///
17781810
void emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
17791811
llvm::Function *OutlinedFn,
17801812
ArrayRef<llvm::Value *> CapturedVars,
1781-
const Expr *IfCond, llvm::Value *NumThreads) override;
1813+
const Expr *IfCond, llvm::Value *NumThreads,
1814+
OpenMPNumThreadsClauseModifier NumThreadsModifier =
1815+
OMPC_NUMTHREADS_unknown,
1816+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1817+
const Expr *Message = nullptr) override;
17821818

17831819
/// Emits a critical region.
17841820
/// \param CriticalName Name of the critical region.
@@ -1949,9 +1985,16 @@ class CGOpenMPSIMDRuntime final : public CGOpenMPRuntime {
19491985
/// Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32
19501986
/// global_tid, kmp_int32 num_threads) to generate code for 'num_threads'
19511987
/// clause.
1988+
/// If the modifier 'strict' is given:
1989+
/// Emits call to void __kmpc_push_num_threads_strict(ident_t *loc, kmp_int32
1990+
/// global_tid, kmp_int32 num_threads, int severity, const char *message) to
1991+
/// generate code for 'num_threads' clause with 'strict' modifier.
19521992
/// \param NumThreads An integer value of threads.
1953-
void emitNumThreadsClause(CodeGenFunction &CGF, llvm::Value *NumThreads,
1954-
SourceLocation Loc) override;
1993+
void emitNumThreadsClause(
1994+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
1995+
OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown,
1996+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1997+
const Expr *Message = nullptr) override;
19551998

19561999
/// Emit call to void __kmpc_push_proc_bind(ident_t *loc, kmp_int32
19572000
/// global_tid, int proc_bind) to generate code for 'proc_bind' clause.

0 commit comments

Comments
 (0)