Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions offload/DeviceRTL/include/DeviceTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ struct omp_lock_t {
void *Lock;
};

/// Severity levels used when reporting a runtime error condition.
/// The enumerator values mirror the host-side definition in openmp/runtime
/// kmp.h and must stay numerically identical to it.
enum omp_severity_t {
  severity_warning = 1, ///< Diagnostic only.
  severity_fatal = 2    ///< Diagnostic followed by termination.
};

using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num);
using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id,
int16_t lane_offset, int16_t shortCircuit);
Expand Down
77 changes: 59 additions & 18 deletions offload/DeviceRTL/src/Parallelism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,24 @@ using namespace ompx;

namespace {

uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
/// Report that the number of threads selected for a parallel region differs
/// from the number requested.
///
/// Prints \p nt_message when it is non-null; otherwise prints a default
/// diagnostic naming both thread counts. If \p nt_severity equals
/// severity_fatal, execution is aborted with __builtin_trap() after the
/// message is printed; for any other severity the function returns normally.
///
/// \param nt_strict   Not consulted by this function.
/// \param nt_severity Severity of the diagnostic (severity_warning or
///                    severity_fatal).
/// \param nt_message  Optional user-supplied message; may be nullptr.
/// \param requested   Number of threads that was requested.
/// \param actual      Number of threads that was actually computed.
void numThreadsStrictError(int32_t nt_strict, int32_t nt_severity,
                           const char *nt_message, int32_t requested,
                           int32_t actual) {
  if (nt_message) {
    printf("%s\n", nt_message);
  } else {
    // `actual` is a signed 32-bit value just like `requested`, so it is
    // printed with %d; using %u here would mismatch the argument type.
    printf("The computed number of threads (%d) does not match the requested "
           "number of threads (%d). Consider that it might not be supported "
           "to select exactly %d threads on this target device.\n",
           actual, requested, requested);
  }

  // Only a fatal severity terminates execution; warnings fall through.
  if (nt_severity == severity_fatal)
    __builtin_trap();
}

uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
int32_t nt_strict = false,
int32_t nt_severity = severity_fatal,
const char *nt_message = nullptr) {
uint32_t NThreadsICV =
NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
uint32_t NumThreads = mapping::getMaxTeamThreads();
Expand All @@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {

// SPMD mode allows any number of threads, for generic mode we round down to a
// multiple of WARPSIZE since it is legal to do so in OpenMP.
if (mapping::isSPMDMode())
return NumThreads;
if (!mapping::isSPMDMode()) {
if (NumThreads < mapping::getWarpSize())
NumThreads = 1;
else
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
}

if (NumThreads < mapping::getWarpSize())
NumThreads = 1;
else
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
if (NumThreadsClause != -1 && nt_strict &&
NumThreads != static_cast<uint32_t>(NumThreadsClause))
numThreadsStrictError(nt_strict, nt_severity, nt_message, NumThreadsClause,
NumThreads);

return NumThreads;
}
Expand All @@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {

extern "C" {

[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
int32_t num_threads,
void *fn, void **args,
const int64_t nargs) {
[[clang::always_inline]] void
__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
const int64_t nargs, int32_t nt_strict = false,
int32_t nt_severity = severity_fatal,
const char *nt_message = nullptr) {
uint32_t TId = mapping::getThreadIdInBlock();
uint32_t NumThreads = determineNumberOfThreads(num_threads);
uint32_t NumThreads =
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
uint32_t PTeamSize =
NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
// Avoid the race between the read of the `icv::Level` above and the write
Expand Down Expand Up @@ -140,10 +163,11 @@ extern "C" {
return;
}

[[clang::always_inline]] void
__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
int32_t num_threads, int proc_bind, void *fn,
void *wrapper_fn, void **args, int64_t nargs) {
[[clang::always_inline]] void __kmpc_parallel_51(
IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
const char *nt_message = nullptr) {
uint32_t TId = mapping::getThreadIdInBlock();

// Assert the parallelism level is zero if disabled by the user.
Expand All @@ -156,6 +180,11 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// 3) nested parallel regions
if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
(config::mayUseNestedParallelism() && icv::Level))) {
// OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
// effect when parallel execution is disabled by a corresponding if clause
// attached to the parallel directive.
if (nt_strict && num_threads > 1)
numThreadsStrictError(nt_strict, nt_severity, nt_message, num_threads, 1);
state::DateEnvironmentRAII DERAII(ident);
++icv::Level;
invokeMicrotask(TId, 0, fn, args, nargs);
Expand All @@ -169,12 +198,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// This was moved to its own routine so it could be called directly
// in certain situations to avoid resource consumption of unused
// logic in parallel_51.
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
nt_severity, nt_message);

return;
}

uint32_t NumThreads = determineNumberOfThreads(num_threads);
uint32_t NumThreads =
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;

Expand Down Expand Up @@ -277,6 +308,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
__kmpc_end_sharing_variables();
}

/// Entry point with the same parameter list as __kmpc_parallel_51.
///
/// Performs no work of its own: every argument, including the
/// strict-modifier triple (\p nt_strict, \p nt_severity, \p nt_message), is
/// passed through unchanged to __kmpc_parallel_51.
[[clang::always_inline]] void __kmpc_parallel_60(
    IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
    const char *nt_message = nullptr) {
  // Pure forwarding; the callee returns void, so there is nothing to return.
  __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
                     wrapper_fn, args, nargs, nt_strict, nt_severity,
                     nt_message);
}

[[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
// Work function and arguments for L1 parallel region.
*WorkFn = state::ParallelRegionFn;
Expand Down
1 change: 1 addition & 0 deletions openmp/runtime/src/kmp.h
Original file line number Diff line number Diff line change
Expand Up @@ -4629,6 +4629,7 @@ extern void (*kmp_target_sync_cb)(ident_t *loc_ref, int gtid,
#endif // ENABLE_LIBOMPTARGET

// Support for error directive
// See definition in offload/DeviceRTL DeviceTypes.h
typedef enum kmp_severity_t {
severity_warning = 1,
severity_fatal = 2
Expand Down
Loading