@@ -45,7 +45,24 @@ using namespace ompx;
4545
4646namespace {
4747
48- uint32_t determineNumberOfThreads (int32_t NumThreadsClause) {
48+ void numThreadsStrictError (int32_t nt_strict, int32_t nt_severity,
49+ const char *nt_message, int32_t requested,
50+ int32_t actual) {
51+ if (nt_message)
52+ printf (" %s\n " , nt_message);
53+ else
54+ printf (" The computed number of threads (%u) does not match the requested "
55+ " number of threads (%d). Consider that it might not be supported "
56+ " to select exactly %d threads on this target device.\n " ,
57+ actual, requested, requested);
58+ if (nt_severity == severity_fatal)
59+ __builtin_trap ();
60+ }
61+
62+ uint32_t determineNumberOfThreads (int32_t NumThreadsClause,
63+ int32_t nt_strict = false ,
64+ int32_t nt_severity = severity_fatal,
65+ const char *nt_message = nullptr ) {
4966 uint32_t NThreadsICV =
5067 NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5168 uint32_t NumThreads = mapping::getMaxTeamThreads ();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5572
5673 // SPMD mode allows any number of threads, for generic mode we round down to a
5774 // multiple of WARPSIZE since it is legal to do so in OpenMP.
58- if (mapping::isSPMDMode ())
59- return NumThreads;
75+ if (!mapping::isSPMDMode ()) {
76+ if (NumThreads < mapping::getWarpSize ())
77+ NumThreads = 1 ;
78+ else
79+ NumThreads = (NumThreads & ~((uint32_t )mapping::getWarpSize () - 1 ));
80+ }
6081
61- if (NumThreads < mapping::getWarpSize ())
62- NumThreads = 1 ;
63- else
64- NumThreads = (NumThreads & ~(( uint32_t ) mapping::getWarpSize () - 1 ) );
82+ if (NumThreadsClause != - 1 && nt_strict &&
83+ NumThreads != static_cast < uint32_t >(NumThreadsClause))
84+ numThreadsStrictError (nt_strict, nt_severity, nt_message, NumThreadsClause,
85+ NumThreads );
6586
6687 return NumThreads;
6788}
@@ -82,12 +103,13 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
82103
83104extern " C" {
84105
85- [[clang::always_inline]] void __kmpc_parallel_spmd (IdentTy *ident,
86- int32_t num_threads ,
87- void *fn, void **args ,
88- const int64_t nargs ) {
106+ [[clang::always_inline]] void __kmpc_parallel_spmd_impl (
107+ IdentTy *ident, int32_t num_threads, void *fn, void **args ,
108+ const int64_t nargs, int32_t nt_strict = false ,
109+ int32_t nt_severity = severity_fatal, const char *nt_message = nullptr ) {
89110 uint32_t TId = mapping::getThreadIdInBlock ();
90- uint32_t NumThreads = determineNumberOfThreads (num_threads);
111+ uint32_t NumThreads =
112+ determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
91113 uint32_t PTeamSize =
92114 NumThreads == mapping::getMaxTeamThreads () ? 0 : NumThreads;
93115 // Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +162,26 @@ extern "C" {
140162 return ;
141163}
142164
143- [[clang::always_inline]] void
144- __kmpc_parallel_51 (IdentTy *ident, int32_t , int32_t if_expr,
145- int32_t num_threads, int proc_bind, void *fn,
146- void *wrapper_fn, void **args, int64_t nargs) {
165+ [[clang::always_inline]] void __kmpc_parallel_spmd (IdentTy *ident,
166+ int32_t num_threads,
167+ void *fn, void **args,
168+ const int64_t nargs) {
169+ return __kmpc_parallel_spmd_impl (ident, num_threads, fn, args, nargs);
170+ }
171+
172+ [[clang::always_inline]] void __kmpc_parallel_spmd_60 (
173+ IdentTy *ident, int32_t num_threads, void *fn, void **args,
174+ const int64_t nargs, int32_t nt_strict = false ,
175+ int32_t nt_severity = severity_fatal, const char *nt_message = nullptr ) {
176+ return __kmpc_parallel_spmd_impl (ident, num_threads, fn, args, nargs,
177+ nt_strict, nt_severity, nt_message);
178+ }
179+
180+ [[clang::always_inline]] void __kmpc_parallel_impl (
181+ IdentTy *ident, int32_t , int32_t if_expr, int32_t num_threads,
182+ int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
183+ int32_t nt_strict = false , int32_t nt_severity = severity_fatal,
184+ const char *nt_message = nullptr ) {
147185 uint32_t TId = mapping::getThreadIdInBlock ();
148186
149187 // Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +194,11 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
156194 // 3) nested parallel regions
157195 if (OMP_UNLIKELY (!if_expr || state::HasThreadState ||
158196 (config::mayUseNestedParallelism () && icv::Level))) {
197+ // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
198+ // effect when parallel execution is disabled by a corresponding if clause
199+ // attached to the parallel directive.
200+ if (nt_strict && num_threads > 1 )
201+ numThreadsStrictError (nt_strict, nt_severity, nt_message, num_threads, 1 );
159202 state::DateEnvironmentRAII DERAII (ident);
160203 ++icv::Level;
161204 invokeMicrotask (TId, 0 , fn, args, nargs);
@@ -169,12 +212,17 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
169212 // This was moved to its own routine so it could be called directly
170213 // in certain situations to avoid resource consumption of unused
171214 // logic in parallel_51.
172- __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs);
215+ if (nt_strict)
216+ __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs);
217+ else
218+ __kmpc_parallel_spmd_60 (ident, num_threads, fn, args, nargs, nt_strict,
219+ nt_severity, nt_message);
173220
174221 return ;
175222 }
176223
177- uint32_t NumThreads = determineNumberOfThreads (num_threads);
224+ uint32_t NumThreads =
225+ determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
178226 uint32_t MaxTeamThreads = mapping::getMaxTeamThreads ();
179227 uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
180228
@@ -277,6 +325,24 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
277325 __kmpc_end_sharing_variables ();
278326}
279327
328+ [[clang::always_inline]] void
329+ __kmpc_parallel_51 (IdentTy *ident, int32_t id, int32_t if_expr,
330+ int32_t num_threads, int proc_bind, void *fn,
331+ void *wrapper_fn, void **args, int64_t nargs) {
332+ return __kmpc_parallel_impl (ident, id, if_expr, num_threads, proc_bind, fn,
333+ wrapper_fn, args, nargs);
334+ }
335+
336+ [[clang::always_inline]] void __kmpc_parallel_60 (
337+ IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
338+ int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
339+ int32_t nt_strict = false , int32_t nt_severity = severity_fatal,
340+ const char *nt_message = nullptr ) {
341+ return __kmpc_parallel_impl (ident, id, if_expr, num_threads, proc_bind, fn,
342+ wrapper_fn, args, nargs, nt_strict, nt_severity,
343+ nt_message);
344+ }
345+
280346[[clang::noinline]] bool __kmpc_kernel_parallel (ParallelRegionFnTy *WorkFn) {
281347 // Work function and arguments for L1 parallel region.
282348 *WorkFn = state::ParallelRegionFn;
0 commit comments