@@ -45,7 +45,24 @@ using namespace ompx;
4545
4646namespace {
4747
48- uint32_t determineNumberOfThreads (int32_t NumThreadsClause) {
48+ void num_threads_strict_error (int32_t nt_strict, int32_t nt_severity,
49+ const char *nt_message, int32_t requested,
50+ int32_t actual) {
51+ if (nt_message)
52+ printf (" %s\n " , nt_message);
53+ else
54+ printf (" The computed number of threads (%u) does not match the requested "
55+ " number of threads (%d). Consider that it might not be supported "
56+ " to select exactly %d threads on this target device.\n " ,
57+ actual, requested, requested);
58+ if (nt_severity == severity_fatal)
59+ __builtin_trap ();
60+ }
61+
62+ uint32_t determineNumberOfThreads (int32_t NumThreadsClause,
63+ int32_t nt_strict = false ,
64+ int32_t nt_severity = severity_fatal,
65+ const char *nt_message = nullptr ) {
4966 uint32_t NThreadsICV =
5067 NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5168 uint32_t NumThreads = mapping::getMaxTeamThreads ();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5572
5673 // SPMD mode allows any number of threads, for generic mode we round down to a
5774 // multiple of WARPSIZE since it is legal to do so in OpenMP.
58- if (mapping::isSPMDMode ())
59- return NumThreads;
75+ if (!mapping::isSPMDMode ()) {
76+ if (NumThreads < mapping::getWarpSize ())
77+ NumThreads = 1 ;
78+ else
79+ NumThreads = (NumThreads & ~((uint32_t )mapping::getWarpSize () - 1 ));
80+ }
6081
61- if (NumThreads < mapping::getWarpSize ())
62- NumThreads = 1 ;
63- else
64- NumThreads = (NumThreads & ~(( uint32_t ) mapping::getWarpSize () - 1 ) );
82+ if (NumThreadsClause != - 1 && nt_strict &&
83+ NumThreads != static_cast < uint32_t >(NumThreadsClause))
84+ num_threads_strict_error (nt_strict, nt_severity, nt_message,
85+ NumThreadsClause, NumThreads );
6586
6687 return NumThreads;
6788}
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
82103
83104extern " C" {
84105
85- [[clang::always_inline]] void __kmpc_parallel_spmd (IdentTy *ident,
86- int32_t num_threads,
87- void *fn, void **args,
88- const int64_t nargs) {
106+ [[clang::always_inline]] void
107+ __kmpc_parallel_spmd (IdentTy *ident, int32_t num_threads, void *fn, void **args,
108+ const int64_t nargs, int32_t nt_strict = false ,
109+ int32_t nt_severity = severity_fatal,
110+ const char *nt_message = nullptr ) {
89111 uint32_t TId = mapping::getThreadIdInBlock ();
90- uint32_t NumThreads = determineNumberOfThreads (num_threads);
112+ uint32_t NumThreads =
113+ determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
91114 uint32_t PTeamSize =
92115 NumThreads == mapping::getMaxTeamThreads () ? 0 : NumThreads;
93116 // Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
140163 return ;
141164}
142165
143- [[clang::always_inline]] void
144- __kmpc_parallel_51 (IdentTy *ident, int32_t , int32_t if_expr,
145- int32_t num_threads, int proc_bind, void *fn,
146- void *wrapper_fn, void **args, int64_t nargs) {
166+ [[clang::always_inline]] void __kmpc_parallel_51 (
167+ IdentTy *ident, int32_t , int32_t if_expr, int32_t num_threads,
168+ int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
169+ int32_t nt_strict = false , int32_t nt_severity = severity_fatal,
170+ const char *nt_message = nullptr ) {
147171 uint32_t TId = mapping::getThreadIdInBlock ();
148172
149173 // Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
156180 // 3) nested parallel regions
157181 if (OMP_UNLIKELY (!if_expr || state::HasThreadState ||
158182 (config::mayUseNestedParallelism () && icv::Level))) {
183+ // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
184+ // effect when parallel execution is disabled by a corresponding if clause
185+ // attached to the parallel directive.
186+ if (nt_strict && num_threads > 1 )
187+ num_threads_strict_error (nt_strict, nt_severity, nt_message, num_threads,
188+ 1 );
159189 state::DateEnvironmentRAII DERAII (ident);
160190 ++icv::Level;
161191 invokeMicrotask (TId, 0 , fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
169199 // This was moved to its own routine so it could be called directly
170200 // in certain situations to avoid resource consumption of unused
171201 // logic in parallel_51.
172- __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs);
202+ __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs, nt_strict,
203+ nt_severity, nt_message);
173204
174205 return ;
175206 }
176207
177- uint32_t NumThreads = determineNumberOfThreads (num_threads);
208+ uint32_t NumThreads =
209+ determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
178210 uint32_t MaxTeamThreads = mapping::getMaxTeamThreads ();
179211 uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
180212
@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
277309 __kmpc_end_sharing_variables ();
278310}
279311
312+ [[clang::always_inline]] void __kmpc_parallel_60 (
313+ IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
314+ int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
315+ int32_t nt_strict = false , int32_t nt_severity = severity_fatal,
316+ const char *nt_message = nullptr ) {
317+ return __kmpc_parallel_51 (ident, id, if_expr, num_threads, proc_bind, fn,
318+ wrapper_fn, args, nargs, nt_strict, nt_severity,
319+ nt_message);
320+ }
321+
280322[[clang::noinline]] bool __kmpc_kernel_parallel (ParallelRegionFnTy *WorkFn) {
281323 // Work function and arguments for L1 parallel region.
282324 *WorkFn = state::ParallelRegionFn;
0 commit comments