@@ -45,7 +45,10 @@ using namespace ompx;
4545
4646namespace {
4747
48- uint32_t determineNumberOfThreads (int32_t NumThreadsClause) {
48+ void handleStrictNumThreadsError () { __builtin_trap (); } // Traps via __builtin_trap(); called when a strict num_threads request is not satisfied.
49+
50+ uint32_t determineNumberOfThreads (int32_t NumThreadsClause,
51+ int32_t Strict) {
4952 uint32_t NThreadsICV =
5053 NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5154 uint32_t NumThreads = mapping::getMaxTeamThreads ();
@@ -55,13 +58,16 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5558
5659 // SPMD mode allows any number of threads, for generic mode we round down to a
5760 // multiple of WARPSIZE since it is legal to do so in OpenMP.
58- if (mapping::isSPMDMode ())
59- return NumThreads;
61+ if (!mapping::isSPMDMode ()) {
62+ if (NumThreads < mapping::getWarpSize ())
63+ NumThreads = 1 ;
64+ else
65+ NumThreads = (NumThreads & ~((uint32_t )mapping::getWarpSize () - 1 ));
66+ }
6067
61- if (NumThreads < mapping::getWarpSize ())
62- NumThreads = 1 ;
63- else
64- NumThreads = (NumThreads & ~((uint32_t )mapping::getWarpSize () - 1 ));
68+ if (NumThreadsClause != -1 && Strict &&
69+ NumThreads != static_cast <uint32_t >(NumThreadsClause))
70+ handleStrictNumThreadsError ();
6571
6672 return NumThreads;
6773}
@@ -85,9 +91,10 @@ extern "C" {
8591[[clang::always_inline]] void __kmpc_parallel_spmd (IdentTy *ident,
8692 int32_t num_threads,
8793 void *fn, void **args,
88- const int64_t nargs) {
94+ const int64_t nargs,
95+ int32_t nt_strict) {
8996 uint32_t TId = mapping::getThreadIdInBlock ();
90- uint32_t NumThreads = determineNumberOfThreads (num_threads);
97+ uint32_t NumThreads = determineNumberOfThreads (num_threads, nt_strict );
9198 uint32_t PTeamSize =
9299 NumThreads == mapping::getMaxTeamThreads () ? 0 : NumThreads;
93100 // Avoid the race between the read of the `icv::Level` above and the write
@@ -157,6 +164,11 @@ __kmpc_parallel_60(IdentTy *ident, int32_t, int32_t if_expr,
157164 // 3) nested parallel regions
158165 if (OMP_UNLIKELY (!if_expr || state::HasThreadState ||
159166 (config::mayUseNestedParallelism () && icv::Level))) {
167+ // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
168+ // effect when parallel execution is disabled by a corresponding if clause
169+ // attached to the parallel directive.
170+ if (nt_strict && num_threads > 1 )
171+ handleStrictNumThreadsError ();
160172 state::DateEnvironmentRAII DERAII (ident);
161173 ++icv::Level;
162174 invokeMicrotask (TId, 0 , fn, args, nargs);
@@ -170,12 +182,12 @@ __kmpc_parallel_60(IdentTy *ident, int32_t, int32_t if_expr,
170182 // This was moved to its own routine so it could be called directly
171183 // in certain situations to avoid resource consumption of unused
172184 // logic in parallel_60.
173- __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs);
185+ __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs, nt_strict );
174186
175187 return ;
176188 }
177189
178- uint32_t NumThreads = determineNumberOfThreads (num_threads);
190+ uint32_t NumThreads = determineNumberOfThreads (num_threads, nt_strict );
179191 uint32_t MaxTeamThreads = mapping::getMaxTeamThreads ();
180192 uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
181193
0 commit comments