Skip to content

Commit 4b4239b

Browse files
committed
ADI/comm: exchange context_id in MPID_Intercomm_exchange
We can easily exchange the context_id along with the rest of the remote info rather than doing it in a separate step. We can determine is_low_group by comparing world namespaces and world ranks entirely in the MPIR layer, so it no longer needs to be returned by MPID_Intercomm_exchange. Rename MPID_Intercomm_exchange_map to MPID_Intercomm_exchange to better reflect that it is not just exchanging maps.
1 parent 226ee9b commit 4b4239b

File tree

6 files changed

+99
-124
lines changed

6 files changed

+99
-124
lines changed

src/mpi/comm/comm_impl.c

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -952,88 +952,84 @@ int MPIR_Comm_set_info_impl(MPIR_Comm * comm_ptr, MPIR_Info * info_ptr)
952952
goto fn_exit;
953953
}
954954

955+
/* determine_low_group - arbitrarily, but deterministically, decide whether the
 * calling process is on the "low group" side of an intercommunicator by
 * comparing world namespaces and world ranks.
 *
 * remote_lpid      - an lpid taken from the remote group; its world index and
 *                    world rank are extracted with MPIR_LPID_WORLD_INDEX and
 *                    MPIR_LPID_WORLD_RANK.
 * is_low_group_out - output; assigned true or false on every path.
 *
 * The local side is always described by world index 0 and MPIR_Process.rank.
 * Returns mpi_errno, which is never changed from MPI_SUCCESS in this routine. */
957+
static int determine_low_group(MPIR_Lpid remote_lpid, bool * is_low_group_out)
958+
{
959+
int mpi_errno = MPI_SUCCESS;
960+
961+
int my_world_idx = 0;
962+
int my_world_rank = MPIR_Process.rank;
963+
int remote_world_idx = MPIR_LPID_WORLD_INDEX(remote_lpid);
964+
int remote_world_rank = MPIR_LPID_WORLD_RANK(remote_lpid);
965+
966+
if (my_world_idx == remote_world_idx) {
967+
/* same world, just compare world ranks: the local side is "low" when its
 * world rank is strictly smaller than the remote one.
 * NOTE(review): my_world_rank is the calling process's own
 * MPIR_Process.rank, so members of the same local group could presumably
 * reach different answers if their world ranks fall on both sides of
 * remote_world_rank -- confirm that callers get a consistent result across
 * the whole local group. */
968+
MPIR_Assert(my_world_idx == 0);
969+
*is_low_group_out = (my_world_rank < remote_world_rank);
970+
} else {
971+
/* different world, compare namespace strings (at most MPIR_NAMESPACE_MAX
 * characters). The two namespaces are asserted to differ. The local side
 * is "low" when its namespace compares greater than the remote one;
 * presumably the remote side runs the mirrored comparison and so obtains
 * the opposite answer. */
972+
int cmp_result = strncmp(MPIR_Worlds[my_world_idx].namespace,
973+
MPIR_Worlds[remote_world_idx].namespace,
974+
MPIR_NAMESPACE_MAX);
975+
MPIR_Assert(cmp_result != 0);
976+
if (cmp_result < 0)
977+
*is_low_group_out = false;
978+
else
979+
*is_low_group_out = true;
980+
}
981+
982+
return mpi_errno;
983+
}
984+
955985
int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
956986
MPIR_Comm * peer_comm_ptr, int remote_leader, int tag,
957987
MPIR_Comm ** new_intercomm_ptr)
958988
{
959989
int mpi_errno = MPI_SUCCESS;
960-
int final_context_id, recvcontext_id;
961990
int remote_size = 0;
962991
MPIR_Lpid *remote_lpids = NULL;
963-
int comm_info[3];
964-
int is_low_group = 0;
965992
MPIR_Session *session_ptr = local_comm_ptr->session_ptr;
966993

967994
MPIR_FUNC_ENTER;
968995

969-
/* Shift tag into the tagged coll space */
970-
tag |= MPIR_TAG_COLL_BIT;
971-
972-
mpi_errno = MPID_Intercomm_exchange_map(local_comm_ptr, local_leader,
973-
peer_comm_ptr, remote_leader,
974-
&remote_size, &remote_lpids, &is_low_group);
975-
MPIR_ERR_CHECK(mpi_errno);
976-
977996
/*
978997
* Create the contexts. Each group will have a context for sending
979998
* to the other group. All processes must be involved. Because
980999
* we know that the local and remote groups are disjoint, this
9811000
* step will complete
9821001
*/
983-
MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE,
984-
(MPL_DBG_FDEST, "About to get contextid (local_size=%d) on rank %d",
985-
local_comm_ptr->local_size, local_comm_ptr->rank));
9861002
/* In the multi-threaded case, MPIR_Get_contextid_sparse assumes that the
9871003
* calling routine already holds the single critical section */
9881004
/* TODO: Make sure this is tag-safe */
1005+
int recvcontext_id;
9891006
mpi_errno = MPIR_Get_contextid_sparse(local_comm_ptr, &recvcontext_id, FALSE);
9901007
MPIR_ERR_CHECK(mpi_errno);
9911008
MPIR_Assert(recvcontext_id != 0);
992-
MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE, (MPL_DBG_FDEST, "Got contextid=%d", recvcontext_id));
993-
994-
/* Leaders can now swap context ids and then broadcast the value
995-
* to the local group of processes */
996-
if (local_comm_ptr->rank == local_leader) {
997-
int remote_context_id;
9981009

999-
mpi_errno =
1000-
MPIC_Sendrecv(&recvcontext_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag,
1001-
&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag,
1002-
peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE);
1003-
MPIR_ERR_CHECK(mpi_errno);
1004-
1005-
final_context_id = remote_context_id;
1010+
/* Shift tag into the tagged coll space */
1011+
tag |= MPIR_TAG_COLL_BIT;
10061012

1007-
/* Now, send all of our local processes the remote_lpids,
1008-
* along with the final context id */
1009-
comm_info[0] = final_context_id;
1010-
MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to bcast on local_comm");
1011-
mpi_errno = MPIR_Bcast(comm_info, 1, MPIR_INT_INTERNAL, local_leader,
1012-
local_comm_ptr, MPIR_ERR_NONE);
1013-
MPIR_ERR_CHECK(mpi_errno);
1014-
MPL_DBG_MSG_D(MPIR_DBG_COMM, VERBOSE, "end of bcast on local_comm of size %d",
1015-
local_comm_ptr->local_size);
1016-
} else {
1017-
/* we're the other processes */
1018-
MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to receive bcast on local_comm");
1019-
mpi_errno = MPIR_Bcast(comm_info, 1, MPIR_INT_INTERNAL, local_leader,
1020-
local_comm_ptr, MPIR_ERR_NONE);
1021-
MPIR_ERR_CHECK(mpi_errno);
1013+
int remote_context_id;
1014+
mpi_errno = MPID_Intercomm_exchange(local_comm_ptr, local_leader,
1015+
peer_comm_ptr, remote_leader, tag,
1016+
recvcontext_id, &remote_context_id,
1017+
&remote_size, &remote_lpids);
1018+
MPIR_ERR_CHECK(mpi_errno);
10221019

1023-
/* Extract the context and group sign information */
1024-
final_context_id = comm_info[0];
1025-
}
1020+
bool is_low_group;
1021+
mpi_errno = determine_low_group(remote_lpids[0], &is_low_group);
1022+
MPIR_ERR_CHECK(mpi_errno);
10261023

10271024
/* At last, we now have the information that we need to build the
10281025
* intercommunicator */
10291026

10301027
/* All processes in the local_comm now build the communicator */
10311028

10321029
mpi_errno = MPIR_Comm_create(new_intercomm_ptr);
1033-
if (mpi_errno)
1034-
goto fn_fail;
1030+
MPIR_ERR_CHECK(mpi_errno);
10351031

1036-
(*new_intercomm_ptr)->context_id = final_context_id;
1032+
(*new_intercomm_ptr)->context_id = remote_context_id;
10371033
(*new_intercomm_ptr)->recvcontext_id = recvcontext_id;
10381034
(*new_intercomm_ptr)->remote_size = remote_size;
10391035
(*new_intercomm_ptr)->local_size = local_comm_ptr->local_size;
@@ -1055,6 +1051,7 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
10551051
/* construct remote_group */
10561052
mpi_errno = MPIR_Group_create_map(remote_size, MPI_UNDEFINED, session_ptr, remote_lpids,
10571053
&(*new_intercomm_ptr)->remote_group);
1054+
MPIR_ERR_CHECK(mpi_errno);
10581055

10591056
MPIR_Comm_set_session_ptr(*new_intercomm_ptr, session_ptr);
10601057

src/mpid/ch3/include/mpidpost.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,10 @@ int MPIDI_GPID_ToLpidArray( int size, MPIDI_Gpid gpid[], MPIR_Lpid lpid[] );
186186
int MPIDI_PG_ForwardPGInfo( MPIR_Comm *peer_ptr, MPIR_Comm *comm_ptr,
187187
int nPGids, const MPIDI_Gpid gpids[],
188188
int root );
189-
int MPID_Intercomm_exchange_map( MPIR_Comm *local_comm_ptr, int local_leader,
190-
MPIR_Comm *peer_comm_ptr, int remote_leader,
191-
int *remote_size, MPIR_Lpid **remote_lpids,
192-
int *is_low_group);
189+
int MPID_Intercomm_exchange(MPIR_Comm *local_comm_ptr, int local_leader,
190+
MPIR_Comm *peer_comm_ptr, int remote_leader,
191+
int tag, int context_id, int *remote_context_id,
192+
int *remote_size, MPIR_Lpid **remote_lpids);
193193
int MPID_Create_intercomm_from_lpids( MPIR_Comm *newcomm_ptr,
194194
int size, const MPIR_Lpid lpids[] );
195195
int MPID_Comm_get_lpid(MPIR_Comm *comm_ptr, int idx, MPIR_Lpid *lpid_ptr, bool is_remote);

src/mpid/ch3/src/mpid_vc.c

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -440,24 +440,20 @@ static int check_disjoint_lpids(MPIR_Lpid lpids1[], int n1, MPIR_Lpid lpids2[],
440440
#endif /* HAVE_ERROR_CHECKING */
441441

442442
/*@
443-
MPID_Intercomm_exchange_map - Exchange address mapping for intercomm creation.
443+
MPID_Intercomm_exchange - Exchange remote info for intercomm creation.
444444
@*/
445-
int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader,
446-
MPIR_Comm *peer_comm_ptr, int remote_leader,
447-
int *remote_size, MPIR_Lpid **remote_lpids,
448-
int *is_low_group)
445+
int MPID_Intercomm_exchange(MPIR_Comm *local_comm_ptr, int local_leader,
446+
MPIR_Comm *peer_comm_ptr, int remote_leader, int tag,
447+
int context_id, int *remote_context_id,
448+
int *remote_size, MPIR_Lpid **remote_lpids)
449449
{
450450
int mpi_errno = MPI_SUCCESS;
451451
int singlePG;
452452
int local_size;
453453
MPIR_Lpid *local_lpids=0;
454454
MPIDI_Gpid *local_gpids=NULL, *remote_gpids=NULL;
455-
int comm_info[2];
456-
int cts_tag;
457455
MPIR_CHKLMEM_DECL();
458456

459-
cts_tag = 0 | MPIR_TAG_COLL_BIT;
460-
461457
if (local_comm_ptr->rank == local_leader) {
462458

463459
/* First, exchange the group information. If we were certain
@@ -471,13 +467,17 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader,
471467
/* printf( "About to sendrecv in intercomm_create\n" );fflush(stdout);*/
472468
MPL_DBG_MSG_FMT(MPIDI_CH3_DBG_OTHER,VERBOSE,(MPL_DBG_FDEST,"rank %d sendrecv to rank %d", peer_comm_ptr->rank,
473469
remote_leader));
474-
mpi_errno = MPIC_Sendrecv( &local_size, 1, MPIR_INT_INTERNAL,
475-
remote_leader, cts_tag,
476-
remote_size, 1, MPIR_INT_INTERNAL,
477-
remote_leader, cts_tag,
478-
peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE );
470+
int local_ints[2] = {local_size, context_id};
471+
int remote_ints[2];
472+
mpi_errno = MPIC_Sendrecv(local_ints, 2, MPIR_INT_INTERNAL,
473+
remote_leader, tag,
474+
remote_ints, 2, MPIR_INT_INTERNAL,
475+
remote_leader, tag,
476+
peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE );
479477
MPIR_ERR_CHECK(mpi_errno);
480478

479+
*remote_size = remote_ints[0];
480+
*remote_context_id = remote_ints[1];
481481
MPL_DBG_MSG_FMT(MPIDI_CH3_DBG_OTHER,VERBOSE,(MPL_DBG_FDEST, "local size = %d, remote size = %d", local_size,
482482
*remote_size ));
483483
/* With this information, we can now send and receive the
@@ -492,9 +492,9 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader,
492492

493493
/* Exchange the lpid arrays */
494494
mpi_errno = MPIC_Sendrecv( local_gpids, local_size*sizeof(MPIDI_Gpid), MPIR_BYTE_INTERNAL,
495-
remote_leader, cts_tag,
495+
remote_leader, tag,
496496
remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPIR_BYTE_INTERNAL,
497-
remote_leader, cts_tag, peer_comm_ptr,
497+
remote_leader, tag, peer_comm_ptr,
498498
MPI_STATUS_IGNORE, MPIR_ERR_NONE );
499499
MPIR_ERR_CHECK(mpi_errno);
500500

@@ -520,22 +520,18 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader,
520520
}
521521
# endif /* HAVE_ERROR_CHECKING */
522522

523-
/* Make an arbitrary decision about which group of process is
524-
the low group. The LEADERS do this by comparing the
525-
local process ids of the 0th member of the two groups */
526-
(*is_low_group) = local_lpids[0] < (*remote_lpids)[0];
527-
528523
/* At this point, we're done with the local lpids; they'll
529524
be freed with the other local memory on exit */
530525

531526
} /* End of the first phase of the leader communication */
532527
/* Leaders can now swap context ids and then broadcast the value
533528
to the local group of processes */
529+
int comm_info[3];
534530
if (local_comm_ptr->rank == local_leader) {
535531
/* Now, send all of our local processes the remote_lpids,
536532
along with the final context id */
537533
comm_info[0] = *remote_size;
538-
comm_info[1] = *is_low_group;
534+
comm_info[1] = *remote_context_id;
539535
MPL_DBG_MSG(MPIDI_CH3_DBG_OTHER,VERBOSE,"About to bcast on local_comm");
540536
mpi_errno = MPIR_Bcast( comm_info, 2, MPIR_INT_INTERNAL, local_leader, local_comm_ptr, MPIR_ERR_NONE );
541537
MPIR_ERR_CHECK(mpi_errno);
@@ -559,7 +555,7 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader,
559555
MPIR_ERR_CHECK(mpi_errno);
560556

561557
/* Extract the context and group sign information */
562-
*is_low_group = comm_info[1];
558+
*remote_context_id = comm_info[1];
563559
}
564560

565561
/* Finish up by giving the device the opportunity to update

src/mpid/ch4/include/mpidch4.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,10 @@ int MPID_Type_commit_hook(MPIR_Datatype *);
167167
int MPID_Type_free_hook(MPIR_Datatype *);
168168
int MPID_Op_commit_hook(MPIR_Op *);
169169
int MPID_Op_free_hook(MPIR_Op *);
170-
int MPID_Intercomm_exchange_map(MPIR_Comm *, int, MPIR_Comm *, int, int *, MPIR_Lpid **, int *);
170+
int MPID_Intercomm_exchange(MPIR_Comm * local_comm, int local_leader,
171+
MPIR_Comm * peer_comm, int remote_leader, int tag,
172+
int context_id, int *remote_context_id_out,
173+
int *remote_size_out, MPIR_Lpid ** remote_lpids_out);
171174
int MPID_Create_intercomm_from_lpids(MPIR_Comm *, int, const MPIR_Lpid[]);
172175
int MPID_Comm_commit_pre_hook(MPIR_Comm *);
173176
int MPID_Comm_free_hook(MPIR_Comm *);

0 commit comments

Comments
 (0)