@@ -45,7 +45,24 @@ using namespace ompx;
4545
4646namespace  {
4747
48- uint32_t  determineNumberOfThreads (int32_t  NumThreadsClause) {
48+ void  num_threads_strict_error (int32_t  nt_strict, int32_t  nt_severity,
49+                               const  char  *nt_message, int32_t  requested,
50+                               int32_t  actual) {
51+   if  (nt_message)
52+     printf (" %s\n "  , nt_message);
53+   else 
54+     printf (" The computed number of threads (%u) does not match the requested " 
55+            " number of threads (%d). Consider that it might not be supported " 
56+            " to select exactly %d threads on this target device.\n "  ,
57+            actual, requested, requested);
58+   if  (nt_severity == severity_fatal)
59+     __builtin_trap ();
60+ }
61+ 
62+ uint32_t  determineNumberOfThreads (int32_t  NumThreadsClause,
63+                                   int32_t  nt_strict = false ,
64+                                   int32_t  nt_severity = severity_fatal,
65+                                   const  char  *nt_message = nullptr ) {
4966  uint32_t  NThreadsICV =
5067      NumThreadsClause != -1  ? NumThreadsClause : icv::NThreads;
5168  uint32_t  NumThreads = mapping::getMaxTeamThreads ();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5572
5673  //  SPMD mode allows any number of threads, for generic mode we round down to a
5774  //  multiple of WARPSIZE since it is legal to do so in OpenMP.
58-   if  (mapping::isSPMDMode ())
59-     return  NumThreads;
75+   if  (!mapping::isSPMDMode ()) {
76+     if  (NumThreads < mapping::getWarpSize ())
77+       NumThreads = 1 ;
78+     else 
79+       NumThreads = (NumThreads & ~((uint32_t )mapping::getWarpSize () - 1 ));
80+   }
6081
61-   if  (NumThreads <  mapping::getWarpSize ()) 
62-     NumThreads =  1 ; 
63-   else 
64-     NumThreads = (NumThreads & ~(( uint32_t ) mapping::getWarpSize () -  1 ) );
82+   if  (NumThreadsClause != - 1  && nt_strict && 
83+        NumThreads !=  static_cast < uint32_t >(NumThreadsClause)) 
84+      num_threads_strict_error (nt_strict, nt_severity, nt_message, 
85+                              NumThreadsClause, NumThreads );
6586
6687  return  NumThreads;
6788}
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
82103
83104extern  " C"   {
84105
85- [[clang::always_inline]] void  __kmpc_parallel_spmd (IdentTy *ident,
86-                                                    int32_t  num_threads,
87-                                                    void  *fn, void  **args,
88-                                                    const  int64_t  nargs) {
106+ [[clang::always_inline]] void 
107+ __kmpc_parallel_spmd (IdentTy *ident, int32_t  num_threads, void  *fn, void  **args,
108+                      const  int64_t  nargs, int32_t  nt_strict = false ,
109+                      int32_t  nt_severity = severity_fatal,
110+                      const  char  *nt_message = nullptr ) {
89111  uint32_t  TId = mapping::getThreadIdInBlock ();
90-   uint32_t  NumThreads = determineNumberOfThreads (num_threads);
112+   uint32_t  NumThreads =
113+       determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
91114  uint32_t  PTeamSize =
92115      NumThreads == mapping::getMaxTeamThreads () ? 0  : NumThreads;
93116  //  Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
140163  return ;
141164}
142165
143- [[clang::always_inline]] void 
144- __kmpc_parallel_51 (IdentTy *ident, int32_t , int32_t  if_expr,
145-                    int32_t  num_threads, int  proc_bind, void  *fn,
146-                    void  *wrapper_fn, void  **args, int64_t  nargs) {
166+ [[clang::always_inline]] void  __kmpc_parallel_51 (
167+     IdentTy *ident, int32_t , int32_t  if_expr, int32_t  num_threads,
168+     int  proc_bind, void  *fn, void  *wrapper_fn, void  **args, int64_t  nargs,
169+     int32_t  nt_strict = false , int32_t  nt_severity = severity_fatal,
170+     const  char  *nt_message = nullptr ) {
147171  uint32_t  TId = mapping::getThreadIdInBlock ();
148172
149173  //  Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
156180  //  3) nested parallel regions
157181  if  (OMP_UNLIKELY (!if_expr || state::HasThreadState ||
158182                   (config::mayUseNestedParallelism () && icv::Level))) {
183+     //  OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
184+     //  effect when parallel execution is disabled by a corresponding if clause
185+     //  attached to the parallel directive.
186+     if  (nt_strict && num_threads > 1 )
187+       num_threads_strict_error (nt_strict, nt_severity, nt_message, num_threads,
188+                                1 );
159189    state::DateEnvironmentRAII DERAII (ident);
160190    ++icv::Level;
161191    invokeMicrotask (TId, 0 , fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
169199    //  This was moved to its own routine so it could be called directly
170200    //  in certain situations to avoid resource consumption of unused
171201    //  logic in parallel_51.
172-     __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs);
202+     __kmpc_parallel_spmd (ident, num_threads, fn, args, nargs, nt_strict,
203+                          nt_severity, nt_message);
173204
174205    return ;
175206  }
176207
177-   uint32_t  NumThreads = determineNumberOfThreads (num_threads);
208+   uint32_t  NumThreads =
209+       determineNumberOfThreads (num_threads, nt_strict, nt_severity, nt_message);
178210  uint32_t  MaxTeamThreads = mapping::getMaxTeamThreads ();
179211  uint32_t  PTeamSize = NumThreads == MaxTeamThreads ? 0  : NumThreads;
180212
@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
277309    __kmpc_end_sharing_variables ();
278310}
279311
312+ [[clang::always_inline]] void  __kmpc_parallel_60 (
313+     IdentTy *ident, int32_t  id, int32_t  if_expr, int32_t  num_threads,
314+     int  proc_bind, void  *fn, void  *wrapper_fn, void  **args, int64_t  nargs,
315+     int32_t  nt_strict = false , int32_t  nt_severity = severity_fatal,
316+     const  char  *nt_message = nullptr ) {
317+   return  __kmpc_parallel_51 (ident, id, if_expr, num_threads, proc_bind, fn,
318+                             wrapper_fn, args, nargs, nt_strict, nt_severity,
319+                             nt_message);
320+ }
321+ 
280322[[clang::noinline]] bool  __kmpc_kernel_parallel (ParallelRegionFnTy *WorkFn) {
281323  //  Work function and arguments for L1 parallel region.
282324  *WorkFn = state::ParallelRegionFn;
0 commit comments