Skip to content

Commit 69e6da6

Browse files
committed
ch4/ucx: add MPIDI_UCX_comm_addr_exchange
Do it the same way as MPIDI_OFI_comm_addr_exchange.
1 parent 58b01e3 commit 69e6da6

File tree

2 files changed

+85
-60
lines changed

2 files changed

+85
-60
lines changed

src/mpid/ch4/netmod/ucx/ucx_impl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win *
130130

131131
int MPIDI_UCX_init_world(void);
132132
int MPIDI_UCX_init_worker(int vci);
133+
int MPIDI_UCX_comm_addr_exchange(MPIR_Comm * comm);
133134

134135
/* am handler for message sent by ucp_am_send_nb */
135136
ucs_status_t MPIDI_UCX_am_handler(void *arg, void *data, size_t length, ucp_ep_h reply_ep,
@@ -150,6 +151,4 @@ void MPIDI_UCX_am_recv_callback_nbx(void *request, ucs_status_t status, size_t l
150151
void *user_data);
151152
#endif
152153

153-
int MPIDI_UCX_init_worker(int vci);
154-
155154
#endif /* UCX_IMPL_H_INCLUDED */

src/mpid/ch4/netmod/ucx/ucx_init.c

Lines changed: 84 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -66,76 +66,102 @@ int MPIDI_UCX_init_worker(int vci)
6666
goto fn_exit;
6767
}
6868

69-
static int initial_address_exchange(void)
69+
#define UCX_AV_INSERT(av, lpid, name) \
70+
do { \
71+
if (MPIDI_UCX_AV(av).dest[0][0] == NULL) { \
72+
ucs_status_t ucx_status; \
73+
ucp_ep_params_t ep_params; \
74+
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; \
75+
ep_params.address = (ucp_address_t *) (name); \
76+
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params, &MPIDI_UCX_AV(av).dest[0][0]); \
77+
MPIDI_UCX_CHK_STATUS(ucx_status); \
78+
MPIDIU_upidhash_add(ep_params.address, addrnamelen, lpid); \
79+
} \
80+
} while (0)
81+
82+
int MPIDI_UCX_comm_addr_exchange(MPIR_Comm * comm)
7083
{
7184
int mpi_errno = MPI_SUCCESS;
72-
ucs_status_t ucx_status;
73-
MPIR_Comm *init_comm = NULL;
74-
75-
void *table;
76-
int recv_bc_len;
77-
int size = MPIR_Process.size;
78-
int rank = MPIR_Process.rank;
79-
mpi_errno = MPIDU_bc_table_create(rank, size, MPIR_Process.node_map,
80-
MPIDI_UCX_global.ctx[0].if_address,
81-
(int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
82-
MPIR_CVAR_CH4_ROOTS_ONLY_PMI, &table, &recv_bc_len);
83-
MPIR_ERR_CHECK(mpi_errno);
85+
MPIR_CHKLMEM_DECL();
8486

85-
ucp_ep_params_t ep_params;
86-
if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) {
87-
int *node_roots = MPIR_Process.node_root_map;
88-
int num_nodes = MPIR_Process.num_nodes;
89-
int *rank_map;
87+
/* only comm_world for now */
88+
MPIR_Assert(comm == MPIR_Process.comm_world);
9089

91-
mpi_errno = MPIDI_create_init_comm(&init_comm);
92-
MPIR_ERR_CHECK(mpi_errno);
90+
MPIR_Assert(comm->attr & MPIR_COMM_ATTR__HIERARCHY);
9391

94-
for (int i = 0; i < num_nodes; i++) {
95-
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
96-
ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
97-
ucx_status =
98-
ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
99-
&MPIDI_UCX_AV(MPIDIU_lpid_to_av(node_roots[i])).dest[0][0]);
100-
MPIDI_UCX_CHK_STATUS(ucx_status);
101-
MPIDIU_upidhash_add(ep_params.address, recv_bc_len, node_roots[i]);
102-
}
103-
mpi_errno = MPIDU_bc_allgather(init_comm, MPIDI_UCX_global.ctx[0].if_address,
104-
(int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
105-
(void **) &table, &rank_map, &recv_bc_len);
92+
char *addrname = (void *) MPIDI_UCX_global.ctx[0].if_address;
93+
int addrnamelen = MPID_MAX_BC_SIZE; /* 4096 */
94+
MPIR_Assert(MPIDI_UCX_global.ctx[0].addrname_len <= addrnamelen);
95+
96+
int local_rank = comm->local_rank;
97+
int external_size = comm->num_external;
98+
99+
if (external_size == 1) {
100+
/* skip root addrexch if we are the only node */
101+
goto all_addrexch;
102+
}
103+
104+
/* PMI allgather over node roots and av_insert to activate node_roots_comm */
105+
if (local_rank == 0) {
106+
char *roots_names;
107+
MPIR_CHKLMEM_MALLOC(roots_names, external_size * addrnamelen);
108+
109+
MPIR_PMI_DOMAIN domain = MPIR_PMI_DOMAIN_NODE_ROOTS;
110+
mpi_errno = MPIR_pmi_allgather(addrname, addrnamelen, roots_names, addrnamelen, domain);
106111
MPIR_ERR_CHECK(mpi_errno);
107112

108-
/* insert new addresses, skipping over node roots */
109-
for (int i = 0; i < MPIR_Process.size; i++) {
110-
if (rank_map[i] >= 0) {
111-
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
112-
ep_params.address = (ucp_address_t *) ((char *) table + rank_map[i] * recv_bc_len);
113-
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
114-
&MPIDI_UCX_AV(MPIDIU_lpid_to_av(i)).dest[0][0]);
115-
MPIDI_UCX_CHK_STATUS(ucx_status);
116-
MPIDIU_upidhash_add(ep_params.address, recv_bc_len, i);
117-
}
113+
/* insert av and activate node_roots_comm */
114+
MPIR_Comm *node_roots_comm = MPIR_Comm_get_node_roots_comm(comm);
115+
for (int i = 0; i < external_size; i++) {
116+
char *p = (char *) roots_names + i * addrnamelen;
117+
MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(node_roots_comm, i);
118+
MPIR_Lpid lpid = MPIR_comm_rank_to_lpid(node_roots_comm, i);
119+
UCX_AV_INSERT(av, lpid, p);
118120
}
119-
mpi_errno = MPIDU_bc_table_destroy();
120-
MPIR_ERR_CHECK(mpi_errno);
121121
} else {
122-
for (int i = 0; i < size; i++) {
123-
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
124-
ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
125-
ucx_status =
126-
ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
127-
&MPIDI_UCX_AV(MPIDIU_lpid_to_av(i)).dest[0][0]);
128-
MPIDI_UCX_CHK_STATUS(ucx_status);
129-
MPIDIU_upidhash_add(ep_params.address, recv_bc_len, i);
130-
}
131-
mpi_errno = MPIDU_bc_table_destroy();
122+
/* just for the PMI_Barrier */
123+
MPIR_PMI_DOMAIN domain = MPIR_PMI_DOMAIN_NODE_ROOTS;
124+
mpi_errno = MPIR_pmi_allgather(addrname, addrnamelen, NULL, addrnamelen, domain);
132125
MPIR_ERR_CHECK(mpi_errno);
133126
}
134127

135-
fn_exit:
136-
if (init_comm && !mpi_errno) {
137-
MPIDI_destroy_init_comm(&init_comm);
128+
all_addrexch:
129+
if (external_size == comm->local_size) {
130+
/* if no local, we are done. */
131+
goto fn_exit;
132+
}
133+
134+
/* -- rest of the addr exchange over node_comm and node_roots_comm -- */
135+
/* since the order will be rearranged by node, let's exchange the rank along with the name */
136+
struct rankname {
137+
int rank;
138+
char name[];
139+
};
140+
int rankname_len = sizeof(int) + addrnamelen;
141+
142+
struct rankname *my_rankname, *all_ranknames;
143+
MPIR_CHKLMEM_MALLOC(my_rankname, rankname_len);
144+
MPIR_CHKLMEM_MALLOC(all_ranknames, comm->local_size * rankname_len);
145+
146+
my_rankname->rank = comm->rank;
147+
memcpy(my_rankname->name, addrname, addrnamelen);
148+
149+
/* Explicitly use an smp algorithm that only requires a working node_comm and node_roots_comm. */
150+
mpi_errno = MPIR_Allgather_intra_smp_no_order(my_rankname, rankname_len, MPIR_BYTE_INTERNAL,
151+
all_ranknames, rankname_len, MPIR_BYTE_INTERNAL,
152+
comm, MPIR_ERR_NONE);
153+
MPIR_ERR_CHECK(mpi_errno);
154+
155+
/* create av, skipping existing entries */
156+
for (int i = 0; i < comm->local_size; i++) {
157+
struct rankname *p = (struct rankname *) ((char *) all_ranknames + i * rankname_len);
158+
MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, p->rank);
159+
MPIR_Lpid lpid = MPIR_comm_rank_to_lpid(comm, p->rank);
160+
UCX_AV_INSERT(av, lpid, p->name);
138161
}
162+
163+
fn_exit:
164+
MPIR_CHKLMEM_FREEALL();
139165
return mpi_errno;
140166
fn_fail:
141167
goto fn_exit;
@@ -210,7 +236,7 @@ int MPIDI_UCX_init_world(void)
210236
mpi_errno = MPIDI_UCX_init_worker(0);
211237
MPIR_ERR_CHECK(mpi_errno);
212238

213-
mpi_errno = initial_address_exchange();
239+
mpi_errno = MPIDI_UCX_comm_addr_exchange(MPIR_Process.comm_world);
214240
MPIR_ERR_CHECK(mpi_errno);
215241

216242
ucx_initialized = true;

0 commit comments

Comments
 (0)