Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 14 additions & 50 deletions openmp/runtime/src/kmp_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8516,60 +8516,24 @@ void __kmp_aux_set_library(enum library_type arg) {
}
}

/* Shared helper for the teams query API: finds the team to report on. */
// Looks up the calling thread via __kmp_entry_thread(). Returns NULL when the
// thread's th_teams_microtask is unset (i.e. not in a teams construct).
// Otherwise starts at the thread's current team and climbs t_parent links,
// counting nesting levels down from that team's t_level until the counter
// reaches th_teams_level + 1, then returns the team it stopped on.
//
// teams_serialized (out): reset to 0 on entry. On the teams path it is loaded
// from t_serialized of the team(s) visited and decremented as serialized
// levels are consumed; the callers visible in this file treat a result > 1 as
// "teams region is serialized".
static kmp_team_t *__kmp_aux_get_team_info(int &teams_serialized) {
kmp_info_t *thr = __kmp_entry_thread();
// Default for the not-in-teams path.
teams_serialized = 0;
if (thr->th.th_teams_microtask) {
kmp_team_t *team = thr->th.th_team;
int tlevel = thr->th.th_teams_level; // the level of the teams construct
// ii is the running nesting-level counter, starting at the current team.
int ii = team->t.t_level;
teams_serialized = team->t.t_serialized;
// Target level: one nesting level deeper than the teams construct itself.
int level = tlevel + 1;
KMP_DEBUG_ASSERT(ii >= tlevel);
// Climb until the level counter has come down to the target level.
while (ii > level) {
// Consume this team's serialized levels (one counter step each) without
// switching teams; stops early if the target level is reached.
for (teams_serialized = team->t.t_serialized;
(teams_serialized > 0) && (ii > level); teams_serialized--, ii--) {
}
// The team was serialized and all of its serialized levels were consumed:
// move to the parent and re-evaluate the loop condition.
if (team->t.t_serialized && (!teams_serialized)) {
team = team->t.t_parent;
continue;
}
// Otherwise, if the target is still above us, stepping to the parent
// accounts for exactly one level.
if (ii > level) {
team = team->t.t_parent;
ii--;
}
}
return team;
}
// Calling thread is not inside a teams construct.
return NULL;
}

// Returns the team number for the calling thread, or 0 when no enclosing
// teams region is found.
// NOTE(review): this text comes from a rendered diff without +/- markers, so
// two bodies appear back to back. Everything through the first unconditional
// `return 0;` reads as the body being removed and the statements after it as
// its replacement -- confirm against the actual patch.
int __kmp_aux_get_team_num() {
// First body: let __kmp_aux_get_team_info locate the team and report how
// many serialized levels it saw.
int serialized;
kmp_team_t *team = __kmp_aux_get_team_info(serialized);
if (team) {
if (serialized > 1) {
return 0; // teams region is serialized ( 1 team of 1 thread ).
} else {
// The located team's t_master_tid is reported as the team number.
return team->t.t_master_tid;
}
}
// Helper returned NULL: fall back to 0.
return 0;
// Second body: start at the calling thread's current team and follow
// t_parent until the parent's t_pkfn is __kmp_teams_master, or until the
// chain runs out.
auto *team = __kmp_entry_thread()->th.th_team;
while (team && team->t.t_parent &&
team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
team = team->t.t_parent;
// Ran off the top of the chain without finding such a parent: report 0.
if (!team || !team->t.t_parent)
return 0;
// team is now the direct child of the team running __kmp_teams_master; its
// t_master_tid is reported as the team number.
return team->t.t_master_tid;
}

// Returns the number of teams for the calling thread, or 1 when no enclosing
// teams region is found.
// NOTE(review): this text comes from a rendered diff without +/- markers, so
// two bodies appear back to back. Everything through the first unconditional
// `return 1;` reads as the body being removed and the statements after it as
// its replacement -- confirm against the actual patch.
int __kmp_aux_get_num_teams() {
// First body: let __kmp_aux_get_team_info locate the team and report how
// many serialized levels it saw.
int serialized;
kmp_team_t *team = __kmp_aux_get_team_info(serialized);
if (team) {
if (serialized > 1) {
// A result > 1 is answered with a single team.
return 1;
} else {
// The parent team's t_nproc is reported as the number of teams.
return team->t.t_parent->t.t_nproc;
}
}
// Helper returned NULL: fall back to 1.
return 1;
// Second body: start at the calling thread's current team and follow
// t_parent until the parent's t_pkfn is __kmp_teams_master, or until the
// chain runs out.
auto *team = __kmp_entry_thread()->th.th_team;
while (team && team->t.t_parent &&
team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
team = team->t.t_parent;
// Ran off the top of the chain without finding such a parent: report 1.
if (!team || !team->t.t_parent)
return 1;
// The parent is the team running __kmp_teams_master; its t_nproc is
// reported as the number of teams.
return team->t.t_parent->t.t_nproc;
}

/* ------------------------------------------------------------------------ */
Expand Down
13 changes: 2 additions & 11 deletions openmp/runtime/src/kmp_sched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
kmp_uint32 team_id;
kmp_uint32 nteams;
UT trip_count;
kmp_team_t *team;
kmp_info_t *th;

KMP_DEBUG_ASSERT(plastiter && plower && pupper && pupperDist && pstride);
Expand Down Expand Up @@ -540,17 +539,9 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
tid = __kmp_tid_from_gtid(gtid);
th = __kmp_threads[gtid];
nth = th->th.th_team_nproc;
team = th->th.th_team;
KMP_DEBUG_ASSERT(th->th.th_teams_microtask); // we are in the teams construct
// skip optional serialized teams to prevent this from using the wrong teams
// information when called after __kmp_serialized_parallel
// TODO: make __kmp_serialized_parallel eventually call __kmp_fork_in_teams
// to address this edge case
while (team->t.t_parent && team->t.t_serialized)
team = team->t.t_parent;
nteams = th->th.th_teams_size.nteams;
team_id = team->t.t_master_tid;
KMP_DEBUG_ASSERT(nteams == (kmp_uint32)team->t.t_parent->t.t_nproc);
nteams = __kmp_aux_get_num_teams();
team_id = __kmp_aux_get_team_num();

// compute global trip count
if (incr == 1) {
Expand Down
81 changes: 81 additions & 0 deletions openmp/runtime/test/teams/teams_parallel_if.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: %libomp-compile -fopenmp-version=52 && %libomp-run

#include <stdio.h>
#include <stdlib.h>
#include <omp.h>

/// Identity of one thread as seen through the OpenMP query routines.
/// Field order matters: test_api fills this with a positional initializer.
typedef struct {
// Value returned by omp_get_team_num() for the thread.
int team_num;
// Value returned by omp_get_thread_num() for the thread.
int thread_num;
} omp_id_t;

/// Test that every worker thread can retrieve the correct team/thread values.
///
/// Runs a `teams` region with `nteams` teams, each containing a `parallel`
/// region that requests `nthreads` threads (expected to collapse to a single
/// thread when `par_if` is 0). Every thread records its
/// (omp_get_team_num(), omp_get_thread_num()) pair, but only if
/// omp_get_num_teams() / omp_get_num_threads() match the expected sizes and
/// both ids are in range; otherwise its slot keeps the {-1, -1} sentinel.
///
/// Prints a diagnostic and exits with EXIT_FAILURE if any slot differs from
/// the expected pair, or if the bookkeeping arrays cannot be allocated.
void test_api(int nteams, int nthreads, int par_if) {
  int expected_nteams = nteams;
  int expected_nthreads = par_if ? nthreads : 1;
  int expected_size = expected_nteams * expected_nthreads;
  omp_id_t *expected = malloc((size_t)expected_size * sizeof *expected);
  omp_id_t *observed = malloc((size_t)expected_size * sizeof *observed);

  // Fail loudly instead of dereferencing NULL below.
  if (!expected || !observed) {
    printf("allocation failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams,
           nthreads, par_if);
    exit(EXIT_FAILURE);
  }

  // Slot i belongs to thread (i % expected_nthreads) of team
  // (i / expected_nthreads); observed starts out as "never written".
  for (int i = 0; i < expected_size; i++) {
    expected[i].team_num = i / expected_nthreads;
    expected[i].thread_num = i % expected_nthreads;
    observed[i].team_num = -1;
    observed[i].thread_num = -1;
  }

#pragma omp teams num_teams(nteams)
#pragma omp parallel num_threads(nthreads) if (par_if)
  {
    omp_id_t id = {omp_get_team_num(), omp_get_thread_num()};
    // The range checks also guard the array index computed below.
    if (omp_get_num_teams() == expected_nteams &&
        omp_get_num_threads() == expected_nthreads && id.team_num >= 0 &&
        id.team_num < expected_nteams && id.thread_num >= 0 &&
        id.thread_num < expected_nthreads) {
      // Each (team, thread) pair maps to its own slot, so threads never write
      // to the same element.
      int flat_id = id.thread_num + id.team_num * expected_nthreads;
      observed[flat_id] = id;
    }
  }

  for (int i = 0; i < expected_size; i++) {
    if (expected[i].team_num != observed[i].team_num ||
        expected[i].thread_num != observed[i].thread_num) {
      printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
             par_if);
      exit(EXIT_FAILURE);
    }
  }

  // This function is called once per configuration; release the buffers so
  // the sweep in main does not accumulate leaked allocations.
  free(observed);
  free(expected);
}

/// Test if __kmpc_dist_for_static_init works correctly.
///
/// Sums the integers 1..1000 inside a combined
/// `teams distribute parallel for` loop, requesting `nteams` teams and
/// `nthreads` threads with the `if (par_if)` clause, and compares the result
/// with the closed-form value n * (n + 1) / 2. Prints a diagnostic and exits
/// with EXIT_FAILURE when the two differ.
void test_dist(int nteams, int nthreads, int par_if) {
  int last = 1000;
  int want = last * (last + 1) / 2;
  int got = 0;

#pragma omp teams distribute parallel for num_teams(nteams)                    \
    num_threads(nthreads) if (par_if)
  for (int i = 1; i <= last; i++) {
    // The accumulator is shared by every thread of every team.
#pragma omp atomic update
    got += i;
  }

  if (got == want)
    return;

  printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
         par_if);
  exit(EXIT_FAILURE);
}

/// Drive both tests over par_if in {0, 1} and every combination of 1..16
/// teams and 1..16 threads, skipping combinations that need more threads
/// than omp_get_max_threads() reports. Prints "passed" and returns
/// EXIT_SUCCESS if no test exited early.
int main() {
  enum { MAX_TEAMS = 16, MAX_THREADS = 16 };

  for (int par_if = 0; par_if <= 1; ++par_if) {
    for (int nteams = 1; nteams <= MAX_TEAMS; ++nteams) {
      for (int nthreads = 1; nthreads <= MAX_THREADS; ++nthreads) {
        // Only run configurations whose requested resources are granted.
        if (nteams * nthreads <= omp_get_max_threads()) {
          test_api(nteams, nthreads, par_if);
          test_dist(nteams, nthreads, par_if);
        }
      }
    }
  }

  printf("passed\n");
  return EXIT_SUCCESS;
}