diff --git a/openmp/runtime/src/kmp_runtime.cpp b/openmp/runtime/src/kmp_runtime.cpp
index 3e5d671cb7a48..4aa810bd411ca 100644
--- a/openmp/runtime/src/kmp_runtime.cpp
+++ b/openmp/runtime/src/kmp_runtime.cpp
@@ -8516,60 +8516,34 @@ void __kmp_aux_set_library(enum library_type arg) {
   }
 }
 
 /* Getting team information common for all team API */
 // Returns NULL if not in teams construct
-static kmp_team_t *__kmp_aux_get_team_info(int &teams_serialized) {
-  kmp_info_t *thr = __kmp_entry_thread();
-  teams_serialized = 0;
-  if (thr->th.th_teams_microtask) {
-    kmp_team_t *team = thr->th.th_team;
-    int tlevel = thr->th.th_teams_level; // the level of the teams construct
-    int ii = team->t.t_level;
-    teams_serialized = team->t.t_serialized;
-    int level = tlevel + 1;
-    KMP_DEBUG_ASSERT(ii >= tlevel);
-    while (ii > level) {
-      for (teams_serialized = team->t.t_serialized;
-           (teams_serialized > 0) && (ii > level); teams_serialized--, ii--) {
-      }
-      if (team->t.t_serialized && (!teams_serialized)) {
-        team = team->t.t_parent;
-        continue;
-      }
-      if (ii > level) {
-        team = team->t.t_parent;
-        ii--;
-      }
-    }
-    return team;
-  }
-  return NULL;
-}
+static kmp_team_t *__kmp_aux_get_team_info() {
+  kmp_team_t *team = __kmp_entry_thread()->th.th_team;
+  // Climb the parent chain, past any parallel regions nested in the teams
+  // construct, until the parent is the team whose microtask is
+  // __kmp_teams_master.
+  while (team && team->t.t_parent &&
+         team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
+    team = team->t.t_parent;
+  // Reaching a team without a parent means no such ancestor exists.
+  if (!team || !team->t.t_parent)
+    return NULL;
+  return team;
+}
 
+// Returns the master tid of the team found by __kmp_aux_get_team_info(), or 0
+// if there is none.
 int __kmp_aux_get_team_num() {
-  int serialized;
-  kmp_team_t *team = __kmp_aux_get_team_info(serialized);
-  if (team) {
-    if (serialized > 1) {
-      return 0; // teams region is serialized ( 1 team of 1 thread ).
-    } else {
-      return team->t.t_master_tid;
-    }
-  }
-  return 0;
+  kmp_team_t *team = __kmp_aux_get_team_info();
+  return team ? team->t.t_master_tid : 0;
 }
 
+// Returns the thread count of the parent of the team found by
+// __kmp_aux_get_team_info() (i.e. the number of teams), or 1 if there is none.
 int __kmp_aux_get_num_teams() {
-  int serialized;
-  kmp_team_t *team = __kmp_aux_get_team_info(serialized);
-  if (team) {
-    if (serialized > 1) {
-      return 1;
-    } else {
-      return team->t.t_parent->t.t_nproc;
-    }
-  }
-  return 1;
+  kmp_team_t *team = __kmp_aux_get_team_info();
+  return team ? team->t.t_parent->t.t_nproc : 1;
 }
 
 /* ------------------------------------------------------------------------ */
diff --git a/openmp/runtime/src/kmp_sched.cpp b/openmp/runtime/src/kmp_sched.cpp
index 2b1bb6f595f9a..3ae08cc899478 100644
--- a/openmp/runtime/src/kmp_sched.cpp
+++ b/openmp/runtime/src/kmp_sched.cpp
@@ -497,7 +497,6 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
   kmp_uint32 team_id;
   kmp_uint32 nteams;
   UT trip_count;
-  kmp_team_t *team;
   kmp_info_t *th;
 
   KMP_DEBUG_ASSERT(plastiter && plower && pupper && pupperDist && pstride);
@@ -540,17 +539,9 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
   tid = __kmp_tid_from_gtid(gtid);
   th = __kmp_threads[gtid];
   nth = th->th.th_team_nproc;
-  team = th->th.th_team;
   KMP_DEBUG_ASSERT(th->th.th_teams_microtask); // we are in the teams construct
-  // skip optional serialized teams to prevent this from using the wrong teams
-  // information when called after __kmp_serialized_parallel
-  // TODO: make __kmp_serialized_parallel eventually call __kmp_fork_in_teams
-  // to address this edge case
-  while (team->t.t_parent && team->t.t_serialized)
-    team = team->t.t_parent;
-  nteams = th->th.th_teams_size.nteams;
-  team_id = team->t.t_master_tid;
-  KMP_DEBUG_ASSERT(nteams == (kmp_uint32)team->t.t_parent->t.t_nproc);
+  nteams = __kmp_aux_get_num_teams();
+  team_id = __kmp_aux_get_team_num();
 
   // compute global trip count
   if (incr == 1) {
diff --git a/openmp/runtime/test/teams/teams_parallel_if.c b/openmp/runtime/test/teams/teams_parallel_if.c
new file mode 100644
index 0000000000000..82900f740cdaa
--- /dev/null
+++ b/openmp/runtime/test/teams/teams_parallel_if.c
@@ -0,0 +1,97 @@
+// RUN: %libomp-compile -fopenmp-version=52 && %libomp-run
+
+#include <omp.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+typedef struct {
+  int team_num;
+  int thread_num;
+} omp_id_t;
+
+/// Test if each worker thread can retrieve correct ICV values.
+///
+/// Every thread that observes the expected league and team sizes records its
+/// (team_num, thread_num) pair in a slot derived from that pair; afterwards
+/// each slot must hold exactly the pair it was derived from.
+void test_api(int nteams, int nthreads, int par_if) {
+  int expected_nteams = nteams;
+  // A false if-clause serializes the parallel region: one thread per team.
+  int expected_nthreads = par_if ? nthreads : 1;
+  int expected_size = expected_nteams * expected_nthreads;
+  omp_id_t *expected = malloc(expected_size * sizeof(*expected));
+  omp_id_t *observed = malloc(expected_size * sizeof(*observed));
+  if (!expected || !observed) {
+    printf("allocation failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams,
+           nthreads, par_if);
+    exit(EXIT_FAILURE);
+  }
+
+  for (int i = 0; i < expected_size; i++) {
+    expected[i].team_num = i / expected_nthreads;
+    expected[i].thread_num = i % expected_nthreads;
+    observed[i].team_num = -1;
+    observed[i].thread_num = -1;
+  }
+
+#pragma omp teams num_teams(nteams)
+#pragma omp parallel num_threads(nthreads) if (par_if)
+  {
+    omp_id_t id = {omp_get_team_num(), omp_get_thread_num()};
+    if (omp_get_num_teams() == expected_nteams &&
+        omp_get_num_threads() == expected_nthreads && id.team_num >= 0 &&
+        id.team_num < expected_nteams && id.thread_num >= 0 &&
+        id.thread_num < expected_nthreads) {
+      int flat_id = id.thread_num + id.team_num * expected_nthreads;
+      observed[flat_id] = id;
+    }
+  }
+
+  for (int i = 0; i < expected_size; i++) {
+    if (expected[i].team_num != observed[i].team_num ||
+        expected[i].thread_num != observed[i].thread_num) {
+      printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
+             par_if);
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  free(expected);
+  free(observed);
+}
+
+/// Test if __kmp_dist_for_static_init works correctly.
+///
+/// Sums 1..ub under a combined teams/distribute/parallel-for construct; a
+/// mismatch with the closed form means iterations were skipped or repeated.
+void test_dist(int nteams, int nthreads, int par_if) {
+  int ub = 1000;
+  int index_sum_expected = ub * (ub + 1) / 2;
+  int index_sum = 0;
+#pragma omp teams distribute parallel for num_teams(nteams)                    \
+    num_threads(nthreads) if (par_if)
+  for (int i = 1; i <= ub; i++)
+#pragma omp atomic update
+    index_sum += i;
+
+  if (index_sum != index_sum_expected) {
+    printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
+           par_if);
+    exit(EXIT_FAILURE);
+  }
+}
+
+int main() {
+  for (int par_if = 0; par_if < 2; par_if++) {
+    for (int nteams = 1; nteams <= 16; nteams++) {
+      for (int nthreads = 1; nthreads <= 16; nthreads++) {
+        if (omp_get_max_threads() < nteams * nthreads)
+          continue; // make sure requested resources are granted
+        test_api(nteams, nthreads, par_if);
+        test_dist(nteams, nthreads, par_if);
+      }
+    }
+  }
+  printf("passed\n");
+  return EXIT_SUCCESS;
+}