@@ -66,76 +66,102 @@ int MPIDI_UCX_init_worker(int vci)
6666 goto fn_exit ;
6767}
6868
69- static int initial_address_exchange (void )
69+ #define UCX_AV_INSERT (av , lpid , name ) \
70+ do { \
71+ if (MPIDI_UCX_AV(av).dest[0][0] == NULL) { \
72+ ucs_status_t ucx_status; \
73+ ucp_ep_params_t ep_params; \
74+ ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; \
75+ ep_params.address = (ucp_address_t *) (name); \
76+ ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params, &MPIDI_UCX_AV(av).dest[0][0]); \
77+ MPIDI_UCX_CHK_STATUS(ucx_status); \
78+ MPIDIU_upidhash_add(ep_params.address, addrnamelen, lpid); \
79+ } \
80+ } while (0)
81+
82+ int MPIDI_UCX_comm_addr_exchange (MPIR_Comm * comm )
7083{
7184 int mpi_errno = MPI_SUCCESS ;
72- ucs_status_t ucx_status ;
73- MPIR_Comm * init_comm = NULL ;
74-
75- void * table ;
76- int recv_bc_len ;
77- int size = MPIR_Process .size ;
78- int rank = MPIR_Process .rank ;
79- mpi_errno = MPIDU_bc_table_create (rank , size , MPIR_Process .node_map ,
80- MPIDI_UCX_global .ctx [0 ].if_address ,
81- (int ) MPIDI_UCX_global .ctx [0 ].addrname_len , FALSE,
82- MPIR_CVAR_CH4_ROOTS_ONLY_PMI , & table , & recv_bc_len );
83- MPIR_ERR_CHECK (mpi_errno );
85+ MPIR_CHKLMEM_DECL ();
8486
85- ucp_ep_params_t ep_params ;
86- if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI ) {
87- int * node_roots = MPIR_Process .node_root_map ;
88- int num_nodes = MPIR_Process .num_nodes ;
89- int * rank_map ;
87+ /* only comm_world for now */
88+ MPIR_Assert (comm == MPIR_Process .comm_world );
9089
91- mpi_errno = MPIDI_create_init_comm (& init_comm );
92- MPIR_ERR_CHECK (mpi_errno );
90+ MPIR_Assert (comm -> attr & MPIR_COMM_ATTR__HIERARCHY );
9391
94- for (int i = 0 ; i < num_nodes ; i ++ ) {
95- ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
96- ep_params .address = (ucp_address_t * ) ((char * ) table + i * recv_bc_len );
97- ucx_status =
98- ucp_ep_create (MPIDI_UCX_global .ctx [0 ].worker , & ep_params ,
99- & MPIDI_UCX_AV (MPIDIU_lpid_to_av (node_roots [i ])).dest [0 ][0 ]);
100- MPIDI_UCX_CHK_STATUS (ucx_status );
101- MPIDIU_upidhash_add (ep_params .address , recv_bc_len , node_roots [i ]);
102- }
103- mpi_errno = MPIDU_bc_allgather (init_comm , MPIDI_UCX_global .ctx [0 ].if_address ,
104- (int ) MPIDI_UCX_global .ctx [0 ].addrname_len , FALSE,
105- (void * * ) & table , & rank_map , & recv_bc_len );
92+ char * addrname = (void * ) MPIDI_UCX_global .ctx [0 ].if_address ;
93+ int addrnamelen = MPID_MAX_BC_SIZE ; /* 4096 */
94+ MPIR_Assert (MPIDI_UCX_global .ctx [0 ].addrname_len <= addrnamelen );
95+
96+ int local_rank = comm -> local_rank ;
97+ int external_size = comm -> num_external ;
98+
99+ if (external_size == 1 ) {
100+ /* skip root addrexch if we are the only node */
101+ goto all_addrexch ;
102+ }
103+
104+ /* PMI allgather over node roots and av_insert to activate node_roots_comm */
105+ if (local_rank == 0 ) {
106+ char * roots_names ;
107+ MPIR_CHKLMEM_MALLOC (roots_names , external_size * addrnamelen );
108+
109+ MPIR_PMI_DOMAIN domain = MPIR_PMI_DOMAIN_NODE_ROOTS ;
110+ mpi_errno = MPIR_pmi_allgather (addrname , addrnamelen , roots_names , addrnamelen , domain );
106111 MPIR_ERR_CHECK (mpi_errno );
107112
108- /* insert new addresses, skipping over node roots */
109- for (int i = 0 ; i < MPIR_Process .size ; i ++ ) {
110- if (rank_map [i ] >= 0 ) {
111- ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
112- ep_params .address = (ucp_address_t * ) ((char * ) table + rank_map [i ] * recv_bc_len );
113- ucx_status = ucp_ep_create (MPIDI_UCX_global .ctx [0 ].worker , & ep_params ,
114- & MPIDI_UCX_AV (MPIDIU_lpid_to_av (i )).dest [0 ][0 ]);
115- MPIDI_UCX_CHK_STATUS (ucx_status );
116- MPIDIU_upidhash_add (ep_params .address , recv_bc_len , i );
117- }
113+ /* insert av and activate node_roots_comm */
114+ MPIR_Comm * node_roots_comm = MPIR_Comm_get_node_roots_comm (comm );
115+ for (int i = 0 ; i < external_size ; i ++ ) {
116+ char * p = (char * ) roots_names + i * addrnamelen ;
117+ MPIDI_av_entry_t * av = MPIDIU_comm_rank_to_av (node_roots_comm , i );
118+ MPIR_Lpid lpid = MPIR_comm_rank_to_lpid (node_roots_comm , i );
119+ UCX_AV_INSERT (av , lpid , p );
118120 }
119- mpi_errno = MPIDU_bc_table_destroy ();
120- MPIR_ERR_CHECK (mpi_errno );
121121 } else {
122- for (int i = 0 ; i < size ; i ++ ) {
123- ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
124- ep_params .address = (ucp_address_t * ) ((char * ) table + i * recv_bc_len );
125- ucx_status =
126- ucp_ep_create (MPIDI_UCX_global .ctx [0 ].worker , & ep_params ,
127- & MPIDI_UCX_AV (MPIDIU_lpid_to_av (i )).dest [0 ][0 ]);
128- MPIDI_UCX_CHK_STATUS (ucx_status );
129- MPIDIU_upidhash_add (ep_params .address , recv_bc_len , i );
130- }
131- mpi_errno = MPIDU_bc_table_destroy ();
122+ /* just for the PMI_Barrier */
123+ MPIR_PMI_DOMAIN domain = MPIR_PMI_DOMAIN_NODE_ROOTS ;
124+ mpi_errno = MPIR_pmi_allgather (addrname , addrnamelen , NULL , addrnamelen , domain );
132125 MPIR_ERR_CHECK (mpi_errno );
133126 }
134127
135- fn_exit :
136- if (init_comm && !mpi_errno ) {
137- MPIDI_destroy_init_comm (& init_comm );
128+ all_addrexch :
129+ if (external_size == comm -> local_size ) {
130+ /* if no local, we are done. */
131+ goto fn_exit ;
132+ }
133+
134+ /* -- rest of the addr exchange over node_code and node_roots_comm -- */
135+ /* since the orders will be rearranged by nodes, let's echange rank along with name */
136+ struct rankname {
137+ int rank ;
138+ char name [];
139+ };
140+ int rankname_len = sizeof (int ) + addrnamelen ;
141+
142+ struct rankname * my_rankname , * all_ranknames ;
143+ MPIR_CHKLMEM_MALLOC (my_rankname , rankname_len );
144+ MPIR_CHKLMEM_MALLOC (all_ranknames , comm -> local_size * rankname_len );
145+
146+ my_rankname -> rank = comm -> rank ;
147+ memcpy (my_rankname -> name , addrname , addrnamelen );
148+
149+ /* Use an smp algorithm explicitly that only require a working node_comm and node_roots_comm. */
150+ mpi_errno = MPIR_Allgather_intra_smp_no_order (my_rankname , rankname_len , MPIR_BYTE_INTERNAL ,
151+ all_ranknames , rankname_len , MPIR_BYTE_INTERNAL ,
152+ comm , MPIR_ERR_NONE );
153+ MPIR_ERR_CHECK (mpi_errno );
154+
155+ /* create av, skipping existing entries */
156+ for (int i = 0 ; i < comm -> local_size ; i ++ ) {
157+ struct rankname * p = (struct rankname * ) ((char * ) all_ranknames + i * rankname_len );
158+ MPIDI_av_entry_t * av = MPIDIU_comm_rank_to_av (comm , p -> rank );
159+ MPIR_Lpid lpid = MPIR_comm_rank_to_lpid (comm , p -> rank );
160+ UCX_AV_INSERT (av , lpid , p -> name );
138161 }
162+
163+ fn_exit :
164+ MPIR_CHKLMEM_FREEALL ();
139165 return mpi_errno ;
140166 fn_fail :
141167 goto fn_exit ;
@@ -210,7 +236,7 @@ int MPIDI_UCX_init_world(void)
210236 mpi_errno = MPIDI_UCX_init_worker (0 );
211237 MPIR_ERR_CHECK (mpi_errno );
212238
213- mpi_errno = initial_address_exchange ( );
239+ mpi_errno = MPIDI_UCX_comm_addr_exchange ( MPIR_Process . comm_world );
214240 MPIR_ERR_CHECK (mpi_errno );
215241
216242 ucx_initialized = true;
0 commit comments