Skip to content

Commit 3a4f0cf

Browse files
committed
ch4/ucx: move multivci code to ucx_vci.c
All multivci related initialization code in ucx_vci.c.
1 parent dc6e07e commit 3a4f0cf

File tree

3 files changed

+101
-107
lines changed

3 files changed

+101
-107
lines changed

src/mpid/ch4/netmod/ucx/ucx_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win *
129129

130130
#define MPIDI_UCX_WIN_AV_TO_EP(av, vci, vci_target) MPIDI_UCX_AV((av)).dest[vci][vci_target]
131131

132+
int MPIDI_UCX_init_world(void);
133+
int MPIDI_UCX_init_worker(int vci);
134+
132135
/* am handler for message sent by ucp_am_send_nb */
133136
ucs_status_t MPIDI_UCX_am_handler(void *arg, void *data, size_t length, ucp_ep_h reply_ep,
134137
unsigned flags);

src/mpid/ch4/netmod/ucx/ucx_init.c

Lines changed: 1 addition & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ categories :
1818
=== END_MPI_T_CVAR_INFO_BLOCK ===
1919
*/
2020

21-
static void request_init_callback(void *request);
22-
2321
static void request_init_callback(void *request)
2422
{
2523

@@ -28,8 +26,6 @@ static void request_init_callback(void *request)
2826

2927
}
3028

31-
static void flush_all(void);
32-
3329
int MPIDI_UCX_init_worker(int vci)
3430
{
3531
int mpi_errno = MPI_SUCCESS;
@@ -143,68 +139,6 @@ static int initial_address_exchange(void)
143139
goto fn_exit;
144140
}
145141

146-
int MPIDI_UCX_all_vcis_address_exchange(void)
147-
{
148-
int mpi_errno = MPI_SUCCESS;
149-
150-
int size = MPIR_Process.size;
151-
int rank = MPIR_Process.rank;
152-
int num_vcis = MPIDI_UCX_global.num_vcis;
153-
154-
/* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
155-
size_t name_len = MPID_MAX_BC_SIZE;
156-
157-
int my_len = num_vcis * name_len;
158-
char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
159-
MPIR_Assert(all_names);
160-
161-
char *my_names = all_names + rank * my_len;
162-
163-
/* put in my addrnames */
164-
for (int i = 0; i < num_vcis; i++) {
165-
char *vci_addrname = my_names + i * name_len;
166-
memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address,
167-
MPIDI_UCX_global.ctx[i].addrname_len);
168-
}
169-
/* Allgather */
170-
MPIR_Comm *comm = MPIR_Process.comm_world;
171-
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPIR_BYTE_INTERNAL,
172-
all_names, my_len, MPIR_BYTE_INTERNAL, comm,
173-
MPIR_ERR_NONE);
174-
MPIR_ERR_CHECK(mpi_errno);
175-
176-
/* insert the addresses */
177-
ucp_ep_params_t ep_params;
178-
for (int vci_local = 0; vci_local < num_vcis; vci_local++) {
179-
for (int r = 0; r < size; r++) {
180-
MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
181-
for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) {
182-
if (vci_local == 0 && vci_remote == 0) {
183-
/* don't overwrite existing addr, or bad things will happen */
184-
continue;
185-
}
186-
int idx = r * num_vcis + vci_remote;
187-
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
188-
ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
189-
190-
ucs_status_t ucx_status;
191-
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker,
192-
&ep_params, &av->dest[vci_local][vci_remote]);
193-
MPIDI_UCX_CHK_STATUS(ucx_status);
194-
}
195-
}
196-
}
197-
198-
/* Flush all pending wireup operations or it may interfere with RMA flush_ops count. */
199-
flush_all();
200-
201-
fn_exit:
202-
MPL_free(all_names);
203-
return mpi_errno;
204-
fn_fail:
205-
goto fn_exit;
206-
}
207-
208142
int MPIDI_UCX_init_local(int *tag_bits)
209143
{
210144
int mpi_errno = MPI_SUCCESS;
@@ -234,7 +168,7 @@ int MPIDI_UCX_init_local(int *tag_bits)
234168
UCP_PARAM_FIELD_REQUEST_SIZE |
235169
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT;
236170

237-
if (MPIDI_UCX_global.num_vcis > 1) {
171+
if (MPICH_IS_THREADED) {
238172
ucp_params.mt_workers_shared = 1;
239173
ucp_params.field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED;
240174
}
@@ -286,39 +220,10 @@ int MPIDI_UCX_init_world(void)
286220
goto fn_exit;
287221
}
288222

289-
/* static functions for MPIDI_UCX_post_init */
290-
static void flush_cb(void *request, ucs_status_t status)
291-
{
292-
}
293-
294-
static void flush_all(void)
295-
{
296-
void *reqs[MPIDI_CH4_MAX_VCIS];
297-
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
298-
reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb);
299-
}
300-
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
301-
if (reqs[vci] == NULL) {
302-
continue;
303-
} else if (UCS_PTR_IS_ERR(reqs[vci])) {
304-
continue;
305-
} else {
306-
ucs_status_t status;
307-
do {
308-
MPID_Progress_test(NULL);
309-
status = ucp_request_check_status(reqs[vci]);
310-
} while (status == UCS_INPROGRESS);
311-
ucp_request_release(reqs[vci]);
312-
}
313-
}
314-
}
315-
316223
int MPIDI_UCX_post_init(void)
317224
{
318225
int mpi_errno = MPI_SUCCESS;
319226

320-
MPIDI_global.is_initialized = 1;
321-
322227
return mpi_errno;
323228
}
324229

src/mpid/ch4/netmod/ucx/ucx_vci.c

Lines changed: 97 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,122 @@
55

66
#include "mpidimpl.h"
77
#include "ucx_impl.h"
8+
#include "mpidu_bc.h"
9+
10+
static int all_vcis_address_exchange(MPIR_Comm * comm);
11+
static void flush_all(void);
812

913
int MPIDI_UCX_comm_set_vcis(MPIR_Comm * comm, int num_vcis, int *all_num_vcis)
1014
{
1115
int mpi_errno = MPI_SUCCESS;
1216

17+
MPIR_Assert(MPIDI_UCX_global.num_vcis == 1);
1318
MPIDI_UCX_global.num_vcis = num_vcis;
1419

15-
/* set up local vcis */
20+
mpi_errno = MPIR_Allgather_impl(&MPIDI_UCX_global.num_vcis, 1, MPIR_INT_INTERNAL,
21+
all_num_vcis, 1, MPIR_INT_INTERNAL, comm, MPIR_ERR_NONE);
22+
MPIR_ERR_CHECK(mpi_errno);
23+
1624
for (int i = 1; i < MPIDI_UCX_global.num_vcis; i++) {
1725
mpi_errno = MPIDI_UCX_init_worker(i);
1826
MPIR_ERR_CHECK(mpi_errno);
1927
}
2028

21-
/* UCX netmod only support the same number of vcis on all procs */
22-
for (int i = 0; i < comm->local_size; i++) {
23-
all_num_vcis[i] = num_vcis;
24-
}
25-
26-
/* address exchange */
27-
if (num_vcis > 1) {
28-
mpi_errno = MPIDI_UCX_all_vcis_address_exchange();
29-
MPIR_ERR_CHECK(mpi_errno);
30-
}
29+
mpi_errno = all_vcis_address_exchange(comm);
30+
MPIR_ERR_CHECK(mpi_errno);
3131

3232
if (MPIR_CVAR_DEBUG_SUMMARY && comm->rank == 0) {
3333
printf("num_vcis: %d\n", MPIDI_UCX_global.num_vcis);
3434
}
35+
/* Flush all pending wireup operations or it may interfere with RMA flush_ops count.
36+
* Since this require progress in non-zero vcis, we need switch on is_initialized. */
37+
flush_all();
3538

3639
fn_exit:
3740
return mpi_errno;
3841
fn_fail:
3942
goto fn_exit;
4043
}
44+
45+
static int all_vcis_address_exchange(MPIR_Comm * comm)
46+
{
47+
int mpi_errno = MPI_SUCCESS;
48+
49+
int size = MPIR_Process.size;
50+
int rank = MPIR_Process.rank;
51+
int num_vcis = MPIDI_UCX_global.num_vcis;
52+
53+
/* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
54+
size_t name_len = MPID_MAX_BC_SIZE;
55+
56+
int my_len = num_vcis * name_len;
57+
char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
58+
MPIR_Assert(all_names);
59+
60+
char *my_names = all_names + rank * my_len;
61+
62+
/* put in my addrnames */
63+
for (int i = 0; i < num_vcis; i++) {
64+
char *vci_addrname = my_names + i * name_len;
65+
memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address,
66+
MPIDI_UCX_global.ctx[i].addrname_len);
67+
}
68+
/* Allgather */
69+
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPIR_BYTE_INTERNAL,
70+
all_names, my_len, MPIR_BYTE_INTERNAL,
71+
comm, MPIR_ERR_NONE);
72+
MPIR_ERR_CHECK(mpi_errno);
73+
74+
/* insert the addresses */
75+
ucp_ep_params_t ep_params;
76+
for (int vci_local = 0; vci_local < num_vcis; vci_local++) {
77+
for (int r = 0; r < size; r++) {
78+
MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
79+
for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) {
80+
if (vci_local == 0 && vci_remote == 0) {
81+
/* don't overwrite existing addr, or bad things will happen */
82+
continue;
83+
}
84+
int idx = r * num_vcis + vci_remote;
85+
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
86+
ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
87+
88+
ucs_status_t ucx_status;
89+
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker,
90+
&ep_params, &av->dest[vci_local][vci_remote]);
91+
MPIDI_UCX_CHK_STATUS(ucx_status);
92+
}
93+
}
94+
}
95+
fn_exit:
96+
MPL_free(all_names);
97+
return mpi_errno;
98+
fn_fail:
99+
goto fn_exit;
100+
}
101+
102+
static void flush_cb(void *request, ucs_status_t status)
103+
{
104+
}
105+
106+
static void flush_all(void)
107+
{
108+
void *reqs[MPIDI_CH4_MAX_VCIS];
109+
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
110+
reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb);
111+
}
112+
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
113+
if (reqs[vci] == NULL) {
114+
continue;
115+
} else if (UCS_PTR_IS_ERR(reqs[vci])) {
116+
continue;
117+
} else {
118+
ucs_status_t status;
119+
do {
120+
MPID_Progress_test(NULL);
121+
status = ucp_request_check_status(reqs[vci]);
122+
} while (status == UCS_INPROGRESS);
123+
ucp_request_release(reqs[vci]);
124+
}
125+
}
126+
}

0 commit comments

Comments
 (0)