Skip to content

Commit 793af09

Browse files
committed
[OpenMP] Simplify accessing num-teams and team-num
We found an issue with accessing the correct number of teams and the team number when the enclosing region is serialized due to the use of an `if` clause. The existing method appears unable to handle such cases, so this change proposes a simpler way of accessing the team struct bound to the implicit task invoked by each OpenMP team in the league.
1 parent 1e89a76 commit 793af09

File tree

3 files changed

+95
-61
lines changed

3 files changed

+95
-61
lines changed

openmp/runtime/src/kmp_runtime.cpp

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8516,60 +8516,22 @@ void __kmp_aux_set_library(enum library_type arg) {
85168516
}
85178517
}
85188518

8519-
/* Getting team information common for all team API */
8520-
// Returns NULL if not in teams construct
// (i.e. when the calling thread's th_teams_microtask is not set).
// NOTE(review): every line of this helper carries a "-" gutter marker, i.e.
// it is on the removed side of this diff.
//
// teams_serialized is an output parameter: it is reset to 0 on entry and,
// inside a teams construct, ends up holding whatever is left of the current
// team's t_serialized counter once the walk below stops.
8521-
static kmp_team_t *__kmp_aux_get_team_info(int &teams_serialized) {
8522-
kmp_info_t *thr = __kmp_entry_thread();
8523-
teams_serialized = 0;
8524-
if (thr->th.th_teams_microtask) {
8525-
kmp_team_t *team = thr->th.th_team;
8526-
int tlevel = thr->th.th_teams_level; // the level of the teams construct
8527-
int ii = team->t.t_level;
8528-
teams_serialized = team->t.t_serialized;
8529-
int level = tlevel + 1;
8530-
KMP_DEBUG_ASSERT(ii >= tlevel);
8531-
// Walk from the current team towards level tlevel + 1, decrementing ii once
// per serialized level consumed and once per real parent hop.
while (ii > level) {
8532-
// Empty-bodied loop: burns serialized levels of the current team until
// either none remain or the target level is reached.
for (teams_serialized = team->t.t_serialized;
8533-
(teams_serialized > 0) && (ii > level); teams_serialized--, ii--) {
8534-
}
8535-
// All serialized levels of this team were consumed: move to the parent and
// re-evaluate without decrementing ii again.
if (team->t.t_serialized && (!teams_serialized)) {
8536-
team = team->t.t_parent;
8537-
continue;
8538-
}
8539-
// Target level not reached yet: take one real step up the team tree.
if (ii > level) {
8540-
team = team->t.t_parent;
8541-
ii--;
8542-
}
8543-
}
8544-
return team;
8545-
}
8546-
return NULL;
8547-
}
8548-
85498519
// Returns the team number for the calling thread.
// This span is a rendered diff: lines following a "-" gutter marker are the
// pre-commit body (built on __kmp_aux_get_team_info), lines following a "+"
// gutter marker are the post-commit body.
int __kmp_aux_get_team_num() {
8550-
int serialized;
8551-
kmp_team_t *team = __kmp_aux_get_team_info(serialized);
8552-
if (team) {
8553-
if (serialized > 1) {
8554-
return 0; // teams region is serialized ( 1 team of 1 thread ).
8555-
} else {
8556-
return team->t.t_master_tid;
8557-
}
8558-
}
8559-
return 0;
8520+
// Post-commit body: start at the calling thread's current team and climb the
// t_parent chain until the parent's entry point (t_pkfn) is
// __kmp_teams_master, or until a team without a parent is reached.
auto *team = __kmp_entry_thread()->th.th_team;
8521+
while (team && team->t.t_parent &&
8522+
team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
8523+
team = team->t.t_parent;
8524+
// The team number is the t_master_tid of the team the walk stopped at; 0 if
// there is no team at all.
// NOTE(review): when no ancestor was started by __kmp_teams_master the walk
// stops at the parentless team and its t_master_tid is returned - presumably
// 0; confirm.
return team ? team->t.t_master_tid : 0;
85608525
}
85618526

85628527
// Returns the number of teams for the calling thread, or 1 when no
// enclosing teams league is found.
// This span is a rendered diff: lines following a "-" gutter marker are the
// pre-commit body (built on __kmp_aux_get_team_info), lines following a "+"
// gutter marker are the post-commit body.
int __kmp_aux_get_num_teams() {
8563-
int serialized;
8564-
kmp_team_t *team = __kmp_aux_get_team_info(serialized);
8565-
if (team) {
8566-
if (serialized > 1) {
8567-
return 1;
8568-
} else {
8569-
return team->t.t_parent->t.t_nproc;
8570-
}
8571-
}
8572-
return 1;
8528+
// Post-commit body: same upward walk as in __kmp_aux_get_team_num - climb
// t_parent until the parent's t_pkfn is __kmp_teams_master or a parentless
// team is reached.
auto *team = __kmp_entry_thread()->th.th_team;
8529+
while (team && team->t.t_parent &&
8530+
team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
8531+
team = team->t.t_parent;
8532+
// No team, or the walk ran out of parents: report a single team.
if (!team || !team->t.t_parent)
8533+
return 1;
8534+
// Otherwise the league size is the t_nproc of the parent team (the one
// whose t_pkfn is __kmp_teams_master).
return team->t.t_parent->t.t_nproc;
85738535
}
85748536

85758537
/* ------------------------------------------------------------------------ */

openmp/runtime/src/kmp_sched.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
497497
kmp_uint32 team_id;
498498
kmp_uint32 nteams;
499499
UT trip_count;
500-
kmp_team_t *team;
501500
kmp_info_t *th;
502501

503502
KMP_DEBUG_ASSERT(plastiter && plower && pupper && pupperDist && pstride);
@@ -540,17 +539,9 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
540539
tid = __kmp_tid_from_gtid(gtid);
541540
th = __kmp_threads[gtid];
542541
nth = th->th.th_team_nproc;
543-
team = th->th.th_team;
544542
KMP_DEBUG_ASSERT(th->th.th_teams_microtask); // we are in the teams construct
545-
// skip optional serialized teams to prevent this from using the wrong teams
546-
// information when called after __kmp_serialized_parallel
547-
// TODO: make __kmp_serialized_parallel eventually call __kmp_fork_in_teams
548-
// to address this edge case
549-
while (team->t.t_parent && team->t.t_serialized)
550-
team = team->t.t_parent;
551-
nteams = th->th.th_teams_size.nteams;
552-
team_id = team->t.t_master_tid;
553-
KMP_DEBUG_ASSERT(nteams == (kmp_uint32)team->t.t_parent->t.t_nproc);
543+
nteams = __kmp_aux_get_num_teams();
544+
team_id = __kmp_aux_get_team_num();
554545

555546
// compute global trip count
556547
if (incr == 1) {
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// RUN: %libomp-compile -fopenmp-version=52 && %libomp-run
2+
3+
#include <stdio.h>
4+
#include <stdlib.h>
5+
#include <omp.h>
6+
7+
// Pair of ids recorded for one thread; test_api fills it from
// omp_get_team_num() and omp_get_thread_num().
typedef struct {
8+
int team_num; // team number of the recording thread
9+
int thread_num; // thread number of the recording thread within its team
10+
} omp_id_t;
11+
12+
/// Test if each worker threads can retrieve correct icv values.
13+
void test_api(int nteams, int nthreads, int par_if) {
14+
int expected_nteams = nteams;
15+
int expected_nthreads = par_if ? nthreads : 1;
16+
int expected_size = expected_nteams * expected_nthreads;
17+
omp_id_t *expected = (omp_id_t *)malloc(expected_size * sizeof(omp_id_t));
18+
omp_id_t *observed = (omp_id_t *)malloc(expected_size * sizeof(omp_id_t));
19+
20+
for (int i = 0; i < expected_size; i++) {
21+
expected[i].team_num = i / expected_nthreads;
22+
expected[i].thread_num = i % expected_nthreads;
23+
observed[i].team_num = -1;
24+
observed[i].thread_num = -1;
25+
}
26+
27+
#pragma omp teams num_teams(nteams)
28+
#pragma omp parallel num_threads(nthreads) if(par_if)
29+
{
30+
omp_id_t id = {omp_get_team_num(), omp_get_thread_num()};
31+
if (omp_get_num_teams() == expected_nteams &&
32+
omp_get_num_threads() == expected_nthreads &&
33+
id.team_num >= 0 && id.team_num < expected_nteams &&
34+
id.thread_num >= 0 && id.thread_num < expected_nthreads) {
35+
int flat_id = id.thread_num + id.team_num * expected_nthreads;
36+
observed[flat_id] = id;
37+
}
38+
}
39+
40+
for (int i = 0; i < expected_size; i++) {
41+
if (expected[i].team_num != observed[i].team_num ||
42+
expected[i].thread_num != observed[i].thread_num) {
43+
printf("failed at nteams=%d, nthreads=%d, par_if=%d\n",
44+
nteams, nthreads, par_if);
45+
exit(EXIT_FAILURE);
46+
}
47+
}
48+
}
49+
50+
/// Exercise a combined `teams distribute parallel for` loop (per the original
/// note, this targets __kmpc_dist_for_static_init) and check that the sum of
/// all visited indices matches the closed-form value.
///
/// Prints a message and exits with EXIT_FAILURE when the sums differ.
void test_dist(int nteams, int nthreads, int par_if) {
  int last = 1000;
  int got = 0;
  int want = last * (last + 1) / 2; // 1 + 2 + ... + last

#pragma omp teams distribute parallel for num_teams(nteams)                    \
    num_threads(nthreads) if (par_if)
  for (int idx = 1; idx <= last; idx++) {
#pragma omp atomic update
    got += idx;
  }

  if (got == want)
    return;

  printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
         par_if);
  exit(EXIT_FAILURE);
}
67+
68+
/// Drive both checks over every combination of if-clause value (0 and 1),
/// team count (1..16) and thread count (1..16), then report success.
int main() {
  for (int use_par = 0; use_par <= 1; ++use_par) {
    for (int teams = 1; teams <= 16; ++teams) {
      for (int threads = 1; threads <= 16; ++threads) {
        // Only run configurations whose requested resources are granted.
        if (teams * threads <= omp_get_max_threads()) {
          test_api(teams, threads, use_par);
          test_dist(teams, threads, use_par);
        }
      }
    }
  }
  printf("passed\n");
  return EXIT_SUCCESS;
}

0 commit comments

Comments
 (0)