Skip to content

Commit dfbc24d

Browse files
committed
[Clang][OpenMP] Recover strict modifier for num_threads
1 parent 0e92beb commit dfbc24d

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -924,13 +924,6 @@ void CGOpenMPRuntimeGPU::emitNumThreadsClause(
924924
OpenMPNumThreadsClauseModifier Modifier, OpenMPSeverityClauseKind Severity,
925925
SourceLocation SeverityLoc, const Expr *Message,
926926
SourceLocation MessageLoc) {
927-
if (Modifier == OMPC_NUMTHREADS_strict) {
928-
CGM.getDiags().Report(Loc,
929-
diag::warn_omp_gpu_unsupported_modifier_for_clause)
930-
<< "strict" << getOpenMPClauseName(OMPC_num_threads);
931-
return;
932-
}
933-
934927
// Nothing to do.
935928
}
936929

@@ -1238,9 +1231,9 @@ void CGOpenMPRuntimeGPU::emitParallelCall(
12381231
if (!CGF.HaveInsertPoint())
12391232
return;
12401233

1241-
auto &&ParallelGen = [this, Loc, OutlinedFn, CapturedVars, IfCond,
1242-
NumThreads](CodeGenFunction &CGF,
1243-
PrePostActionTy &Action) {
1234+
auto &&ParallelGen = [this, Loc, OutlinedFn, CapturedVars, IfCond, NumThreads,
1235+
NumThreadsModifier](CodeGenFunction &CGF,
1236+
PrePostActionTy &Action) {
12441237
CGBuilderTy &Bld = CGF.Builder;
12451238
llvm::Value *NumThreadsVal = NumThreads;
12461239
llvm::Function *WFn = WrapperFunctionsMap[OutlinedFn];
@@ -1291,8 +1284,8 @@ void CGOpenMPRuntimeGPU::emitParallelCall(
12911284
else
12921285
NumThreadsVal = Bld.CreateZExtOrTrunc(NumThreadsVal, CGF.Int32Ty);
12931286

1294-
// No strict prescriptiveness for the number of threads.
1295-
llvm::Value *StrictNumThreadsVal = llvm::ConstantInt::get(CGF.Int32Ty, 0);
1287+
// Forward whether the strict modifier is specified.
1288+
llvm::Value *StrictNumThreadsVal = llvm::ConstantInt::get(CGM.Int32Ty, NumThreadsModifier == OMPC_NUMTHREADS_strict);
12961289

12971290
assert(IfCondVal && "Expected a value");
12981291
llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);

clang/lib/CodeGen/CGOpenMPRuntimeGPU.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
245245
/// \param NumThreads The value corresponding to the num_threads clause, if
246246
/// any, or nullptr.
247247
/// \param NumThreadsModifier The modifier of the num_threads clause, if
248-
/// any, ignored otherwise. Currently unused on the device.
248+
/// any, ignored otherwise.
249249
/// \param Severity The severity corresponding to the num_threads clause, if
250250
/// any, ignored otherwise. Currently unused on the device.
251251
/// \param Message The message string corresponding to the num_threads clause,

openmp/device/src/Parallelism.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ using namespace ompx;
4545

4646
namespace {
4747

48-
uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
48+
/// Error path for a `strict` num_threads request that cannot be satisfied:
/// stops the calling thread by executing a trap instruction.
void handleStrictNumThreadsError() {
  __builtin_trap();
}
49+
50+
uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
51+
int32_t Strict) {
4952
uint32_t NThreadsICV =
5053
NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5154
uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +58,16 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5558

5659
// SPMD mode allows any number of threads; for generic mode we round down to a
5760
// multiple of WARPSIZE since it is legal to do so in OpenMP.
58-
if (mapping::isSPMDMode())
59-
return NumThreads;
61+
if (!mapping::isSPMDMode()) {
62+
if (NumThreads < mapping::getWarpSize())
63+
NumThreads = 1;
64+
else
65+
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
66+
}
6067

61-
if (NumThreads < mapping::getWarpSize())
62-
NumThreads = 1;
63-
else
64-
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
68+
if (NumThreadsClause != -1 && Strict &&
69+
NumThreads != static_cast<uint32_t>(NumThreadsClause))
70+
handleStrictNumThreadsError();
6571

6672
return NumThreads;
6773
}
@@ -85,9 +91,10 @@ extern "C" {
8591
[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
8692
int32_t num_threads,
8793
void *fn, void **args,
88-
const int64_t nargs) {
94+
const int64_t nargs,
95+
int32_t nt_strict) {
8996
uint32_t TId = mapping::getThreadIdInBlock();
90-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
97+
uint32_t NumThreads = determineNumberOfThreads(num_threads, nt_strict);
9198
uint32_t PTeamSize =
9299
NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
93100
// Avoid the race between the read of the `icv::Level` above and the write
@@ -157,6 +164,11 @@ __kmpc_parallel_60(IdentTy *ident, int32_t, int32_t if_expr,
157164
// 3) nested parallel regions
158165
if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
159166
(config::mayUseNestedParallelism() && icv::Level))) {
167+
// OpenMP 6.0, Section 12.1.2, requires the num_threads 'strict' modifier to also have
168+
// effect when parallel execution is disabled by a corresponding if clause
169+
// attached to the parallel directive.
170+
if (nt_strict && num_threads > 1)
171+
handleStrictNumThreadsError();
160172
state::DateEnvironmentRAII DERAII(ident);
161173
++icv::Level;
162174
invokeMicrotask(TId, 0, fn, args, nargs);
@@ -170,12 +182,12 @@ __kmpc_parallel_60(IdentTy *ident, int32_t, int32_t if_expr,
170182
// This was moved to its own routine so it could be called directly
171183
// in certain situations to avoid resource consumption of unused
172184
// logic in parallel_60.
173-
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
185+
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict);
174186

175187
return;
176188
}
177189

178-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
190+
uint32_t NumThreads = determineNumberOfThreads(num_threads, nt_strict);
179191
uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
180192
uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
181193

0 commit comments

Comments
 (0)