Skip to content

Commit 7d86769

Browse files
authored
Reland: [OpenMP][clang] 6.0: num_threads strict (part 2: device runtime) (llvm#146404) (llvm#3805)
2 parents 63747a4 + 0bc7353 commit 7d86769

File tree

2 files changed

+90
-18
lines changed

2 files changed

+90
-18
lines changed

offload/DeviceRTL/include/DeviceTypes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ struct omp_lock_t {
137137
void *Lock;
138138
};
139139

140+
// Severity values (warning / fatal); see the definition in openmp/runtime
// kmp.h.
enum omp_severity_t { severity_warning = 1, severity_fatal = 2 };
145+
140146
using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num);
141147
using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id,
142148
int16_t lane_offset, int16_t shortCircuit);

offload/DeviceRTL/src/Parallelism.cpp

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,24 @@ using namespace ompx;
4545

4646
namespace {
4747

48-
uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
48+
void numThreadsStrictError(int32_t nt_strict, int32_t nt_severity,
49+
const char *nt_message, int32_t requested,
50+
int32_t actual) {
51+
if (nt_message)
52+
printf("%s\n", nt_message);
53+
else
54+
printf("The computed number of threads (%u) does not match the requested "
55+
"number of threads (%d). Consider that it might not be supported "
56+
"to select exactly %d threads on this target device.\n",
57+
actual, requested, requested);
58+
if (nt_severity == severity_fatal)
59+
__builtin_trap();
60+
}
61+
62+
uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
63+
int32_t nt_strict = false,
64+
int32_t nt_severity = severity_fatal,
65+
const char *nt_message = nullptr) {
4966
uint32_t NThreadsICV =
5067
NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5168
uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5572

5673
// SPMD mode allows any number of threads, for generic mode we round down to a
5774
// multiple of WARPSIZE since it is legal to do so in OpenMP.
58-
if (mapping::isSPMDMode())
59-
return NumThreads;
75+
if (!mapping::isSPMDMode()) {
76+
if (NumThreads < mapping::getWarpSize())
77+
NumThreads = 1;
78+
else
79+
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
80+
}
6081

61-
if (NumThreads < mapping::getWarpSize())
62-
NumThreads = 1;
63-
else
64-
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
82+
if (NumThreadsClause != -1 && nt_strict &&
83+
NumThreads != static_cast<uint32_t>(NumThreadsClause))
84+
numThreadsStrictError(nt_strict, nt_severity, nt_message, NumThreadsClause,
85+
NumThreads);
6586

6687
return NumThreads;
6788
}
@@ -82,12 +103,13 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
82103

83104
extern "C" {
84105

85-
[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
86-
int32_t num_threads,
87-
void *fn, void **args,
88-
const int64_t nargs) {
106+
[[clang::always_inline]] void __kmpc_parallel_spmd_impl(
107+
IdentTy *ident, int32_t num_threads, void *fn, void **args,
108+
const int64_t nargs, int32_t nt_strict = false,
109+
int32_t nt_severity = severity_fatal, const char *nt_message = nullptr) {
89110
uint32_t TId = mapping::getThreadIdInBlock();
90-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
111+
uint32_t NumThreads =
112+
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
91113
uint32_t PTeamSize =
92114
NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
93115
// Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +162,26 @@ extern "C" {
140162
return;
141163
}
142164

143-
[[clang::always_inline]] void
144-
__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
145-
int32_t num_threads, int proc_bind, void *fn,
146-
void *wrapper_fn, void **args, int64_t nargs) {
165+
// Entry point without num_threads 'strict' arguments: forwards to
// __kmpc_parallel_spmd_impl and leaves nt_strict / nt_severity / nt_message
// at that routine's default values.
[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
                                                   int32_t num_threads,
                                                   void *fn, void **args,
                                                   const int64_t nargs) {
  __kmpc_parallel_spmd_impl(ident, num_threads, fn, args, nargs);
}
171+
172+
// Entry point carrying the num_threads 'strict' arguments: passes every
// parameter, including nt_strict, nt_severity and nt_message, through to
// __kmpc_parallel_spmd_impl unchanged.
[[clang::always_inline]] void __kmpc_parallel_spmd_60(
    IdentTy *ident, int32_t num_threads, void *fn, void **args,
    const int64_t nargs, int32_t nt_strict = false,
    int32_t nt_severity = severity_fatal, const char *nt_message = nullptr) {
  __kmpc_parallel_spmd_impl(ident, num_threads, fn, args, nargs, nt_strict,
                            nt_severity, nt_message);
}
179+
180+
[[clang::always_inline]] void __kmpc_parallel_impl(
181+
IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
182+
int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
183+
int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
184+
const char *nt_message = nullptr) {
147185
uint32_t TId = mapping::getThreadIdInBlock();
148186

149187
// Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +194,11 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
156194
// 3) nested parallel regions
157195
if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
158196
(config::mayUseNestedParallelism() && icv::Level))) {
197+
// OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
198+
// effect when parallel execution is disabled by a corresponding if clause
199+
// attached to the parallel directive.
200+
if (nt_strict && num_threads > 1)
201+
numThreadsStrictError(nt_strict, nt_severity, nt_message, num_threads, 1);
159202
state::DateEnvironmentRAII DERAII(ident);
160203
++icv::Level;
161204
invokeMicrotask(TId, 0, fn, args, nargs);
@@ -169,12 +212,17 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
169212
// This was moved to its own routine so it could be called directly
170213
// in certain situations to avoid resource consumption of unused
171214
// logic in parallel_51.
172-
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
215+
if (nt_strict)
216+
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
217+
else
218+
__kmpc_parallel_spmd_60(ident, num_threads, fn, args, nargs, nt_strict,
219+
nt_severity, nt_message);
173220

174221
return;
175222
}
176223

177-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
224+
uint32_t NumThreads =
225+
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
178226
uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
179227
uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
180228

@@ -277,6 +325,24 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
277325
__kmpc_end_sharing_variables();
278326
}
279327

328+
[[clang::always_inline]] void
329+
__kmpc_parallel_51(IdentTy *ident, int32_t id, int32_t if_expr,
330+
int32_t num_threads, int proc_bind, void *fn,
331+
void *wrapper_fn, void **args, int64_t nargs) {
332+
return __kmpc_parallel_impl(ident, id, if_expr, num_threads, proc_bind, fn,
333+
wrapper_fn, args, nargs);
334+
}
335+
336+
// Entry point carrying the num_threads 'strict' arguments: passes every
// parameter, including nt_strict, nt_severity and nt_message, through to
// __kmpc_parallel_impl unchanged.
[[clang::always_inline]] void __kmpc_parallel_60(
    IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
    const char *nt_message = nullptr) {
  __kmpc_parallel_impl(ident, id, if_expr, num_threads, proc_bind, fn,
                       wrapper_fn, args, nargs, nt_strict, nt_severity,
                       nt_message);
}
345+
280346
[[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
281347
// Work function and arguments for L1 parallel region.
282348
*WorkFn = state::ParallelRegionFn;

0 commit comments

Comments
 (0)