Skip to content

Commit f7e1772

Browse files
authored
Merge pull request pmodels#7355 from hzhou/2503_comm_vci
ch4: support comm_set_vcis in ucx and posix. Approved-by: Ken Raffenetti
2 parents dc6e07e + 8b52d55 commit f7e1772

23 files changed

+488
-253
lines changed

src/mpid/ch4/netmod/ofi/ofi_init.c

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,8 @@ categories :
549549
=== END_MPI_T_CVAR_INFO_BLOCK ===
550550
*/
551551

552+
static bool ofi_initialized = false;
553+
552554
static int update_global_limits(struct fi_info *prov);
553555
static void dump_global_settings(void);
554556
static int destroy_vci_context(int vci, int nic);
@@ -799,6 +801,8 @@ int MPIDI_OFI_init_world(void)
799801
MPIR_ERR_CHECK(mpi_errno);
800802
}
801803

804+
ofi_initialized = true;
805+
802806
fn_exit:
803807
return mpi_errno;
804808
fn_fail:
@@ -879,10 +883,14 @@ static int flush_send_queue(void)
879883
{
880884
int mpi_errno = MPI_SUCCESS;
881885

886+
if (!ofi_initialized) {
887+
goto fn_exit;
888+
}
889+
882890
MPIDI_OFI_dynamic_process_request_t *reqs;
883891
/* TODO - Iterate over each NIC in addition to each VNI when multi-NIC within the same
884892
* process is implemented. */
885-
int num_vcis = (MPIDI_global.is_initialized ? MPIDI_OFI_global.num_vcis : 1);
893+
int num_vcis = MPIDI_OFI_global.num_vcis;
886894
int num_reqs = num_vcis * 2;
887895
reqs = MPL_malloc(sizeof(MPIDI_OFI_dynamic_process_request_t) * num_reqs, MPL_MEM_OTHER);
888896

@@ -937,12 +945,9 @@ int MPIDI_OFI_mpi_finalize_hook(void)
937945
MPIDI_OFI_mr_key_allocator_destroy();
938946

939947
if (strcmp("sockets", MPIDI_OFI_global.prov_use[0]->fabric_attr->prov_name) == 0) {
940-
/* sockets provider need flush any last lightweight send. Only do it if we initialized
941-
* world. Sockets provider can't even send self messages otherwise. */
942-
if (MPIDI_global.is_initialized) {
943-
mpi_errno = flush_send_queue();
944-
MPIR_ERR_CHECK(mpi_errno);
945-
}
948+
/* The sockets provider needs to flush any last lightweight send; flush_send_queue() returns early if OFI world init never ran. */
949+
mpi_errno = flush_send_queue();
950+
MPIR_ERR_CHECK(mpi_errno);
946951
} else if (MPIR_CVAR_NO_COLLECTIVE_FINALIZE) {
947952
/* skip collective work arounds */
948953
} else if (strcmp("verbs;ofi_rxm", MPIDI_OFI_global.prov_use[0]->fabric_attr->prov_name) == 0
@@ -980,12 +985,10 @@ int MPIDI_OFI_mpi_finalize_hook(void)
980985
/* Tearing down endpoints in reverse order they were created */
981986
for (int nic = MPIDI_OFI_global.num_nics - 1; nic >= 0; nic--) {
982987
for (int vci = MPIDI_OFI_global.num_vcis - 1; vci >= 0; vci--) {
983-
if (MPIDI_global.is_initialized || (vci == 0 && nic == 0)) {
984-
/* If the user has not freed all MPI objects, ofi might not shut down cleanly.
985-
* We intentionally ignore errors to avoid crashing in finalize. Debug builds
986-
* will warn about unfreed objects/memory. */
987-
(void) destroy_vci_context(vci, nic);
988-
}
988+
/* If the user has not freed all MPI objects, ofi might not shut down cleanly.
989+
* We intentionally ignore errors to avoid crashing in finalize. Debug builds
990+
* will warn about unfreed objects/memory. */
991+
(void) destroy_vci_context(vci, nic);
989992
}
990993
}
991994

src/mpid/ch4/netmod/ucx/ucx_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win *
129129

130130
#define MPIDI_UCX_WIN_AV_TO_EP(av, vci, vci_target) MPIDI_UCX_AV((av)).dest[vci][vci_target]
131131

132+
int MPIDI_UCX_init_world(void);
133+
int MPIDI_UCX_init_worker(int vci);
134+
132135
/* am handler for message sent by ucp_am_send_nb */
133136
ucs_status_t MPIDI_UCX_am_handler(void *arg, void *data, size_t length, ucp_ep_h reply_ep,
134137
unsigned flags);

src/mpid/ch4/netmod/ucx/ucx_init.c

Lines changed: 9 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ categories :
1818
=== END_MPI_T_CVAR_INFO_BLOCK ===
1919
*/
2020

21-
static void request_init_callback(void *request);
21+
static bool ucx_initialized = false;
2222

2323
static void request_init_callback(void *request)
2424
{
@@ -28,8 +28,6 @@ static void request_init_callback(void *request)
2828

2929
}
3030

31-
static void flush_all(void);
32-
3331
int MPIDI_UCX_init_worker(int vci)
3432
{
3533
int mpi_errno = MPI_SUCCESS;
@@ -143,68 +141,6 @@ static int initial_address_exchange(void)
143141
goto fn_exit;
144142
}
145143

146-
int MPIDI_UCX_all_vcis_address_exchange(void)
147-
{
148-
int mpi_errno = MPI_SUCCESS;
149-
150-
int size = MPIR_Process.size;
151-
int rank = MPIR_Process.rank;
152-
int num_vcis = MPIDI_UCX_global.num_vcis;
153-
154-
/* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
155-
size_t name_len = MPID_MAX_BC_SIZE;
156-
157-
int my_len = num_vcis * name_len;
158-
char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
159-
MPIR_Assert(all_names);
160-
161-
char *my_names = all_names + rank * my_len;
162-
163-
/* put in my addrnames */
164-
for (int i = 0; i < num_vcis; i++) {
165-
char *vci_addrname = my_names + i * name_len;
166-
memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address,
167-
MPIDI_UCX_global.ctx[i].addrname_len);
168-
}
169-
/* Allgather */
170-
MPIR_Comm *comm = MPIR_Process.comm_world;
171-
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPIR_BYTE_INTERNAL,
172-
all_names, my_len, MPIR_BYTE_INTERNAL, comm,
173-
MPIR_ERR_NONE);
174-
MPIR_ERR_CHECK(mpi_errno);
175-
176-
/* insert the addresses */
177-
ucp_ep_params_t ep_params;
178-
for (int vci_local = 0; vci_local < num_vcis; vci_local++) {
179-
for (int r = 0; r < size; r++) {
180-
MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
181-
for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) {
182-
if (vci_local == 0 && vci_remote == 0) {
183-
/* don't overwrite existing addr, or bad things will happen */
184-
continue;
185-
}
186-
int idx = r * num_vcis + vci_remote;
187-
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
188-
ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
189-
190-
ucs_status_t ucx_status;
191-
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker,
192-
&ep_params, &av->dest[vci_local][vci_remote]);
193-
MPIDI_UCX_CHK_STATUS(ucx_status);
194-
}
195-
}
196-
}
197-
198-
/* Flush all pending wireup operations or it may interfere with RMA flush_ops count. */
199-
flush_all();
200-
201-
fn_exit:
202-
MPL_free(all_names);
203-
return mpi_errno;
204-
fn_fail:
205-
goto fn_exit;
206-
}
207-
208144
int MPIDI_UCX_init_local(int *tag_bits)
209145
{
210146
int mpi_errno = MPI_SUCCESS;
@@ -234,7 +170,7 @@ int MPIDI_UCX_init_local(int *tag_bits)
234170
UCP_PARAM_FIELD_REQUEST_SIZE |
235171
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT;
236172

237-
if (MPIDI_UCX_global.num_vcis > 1) {
173+
if (MPICH_IS_THREADED) {
238174
ucp_params.mt_workers_shared = 1;
239175
ucp_params.field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED;
240176
}
@@ -277,6 +213,8 @@ int MPIDI_UCX_init_world(void)
277213
mpi_errno = initial_address_exchange();
278214
MPIR_ERR_CHECK(mpi_errno);
279215

216+
ucx_initialized = true;
217+
280218
fn_exit:
281219
return mpi_errno;
282220
fn_fail:
@@ -286,53 +224,23 @@ int MPIDI_UCX_init_world(void)
286224
goto fn_exit;
287225
}
288226

289-
/* static functions for MPIDI_UCX_post_init */
290-
static void flush_cb(void *request, ucs_status_t status)
291-
{
292-
}
293-
294-
static void flush_all(void)
295-
{
296-
void *reqs[MPIDI_CH4_MAX_VCIS];
297-
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
298-
reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb);
299-
}
300-
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
301-
if (reqs[vci] == NULL) {
302-
continue;
303-
} else if (UCS_PTR_IS_ERR(reqs[vci])) {
304-
continue;
305-
} else {
306-
ucs_status_t status;
307-
do {
308-
MPID_Progress_test(NULL);
309-
status = ucp_request_check_status(reqs[vci]);
310-
} while (status == UCS_INPROGRESS);
311-
ucp_request_release(reqs[vci]);
312-
}
313-
}
314-
}
315-
316227
int MPIDI_UCX_post_init(void)
317228
{
318229
int mpi_errno = MPI_SUCCESS;
319230

320-
MPIDI_global.is_initialized = 1;
321-
322231
return mpi_errno;
323232
}
324233

325234
int MPIDI_UCX_mpi_finalize_hook(void)
326235
{
327236
int mpi_errno = MPI_SUCCESS;
328237

329-
if (!MPIDI_global.is_initialized) {
330-
/* Nothing to do */
331-
return mpi_errno;
332-
}
333-
334238
ucs_status_ptr_t ucp_request;
335-
ucs_status_ptr_t *pending;
239+
ucs_status_ptr_t *pending = NULL;
240+
241+
if (!ucx_initialized) {
242+
goto fn_exit;
243+
}
336244

337245
int n = MPIDI_UCX_global.num_vcis;
338246
pending = MPL_malloc(sizeof(ucs_status_ptr_t) * MPIR_Process.size * n * n, MPL_MEM_OTHER);

src/mpid/ch4/netmod/ucx/ucx_vci.c

Lines changed: 97 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,122 @@
55

66
#include "mpidimpl.h"
77
#include "ucx_impl.h"
8+
#include "mpidu_bc.h"
9+
10+
static int all_vcis_address_exchange(MPIR_Comm * comm);
11+
static void flush_all(void);
812

913
int MPIDI_UCX_comm_set_vcis(MPIR_Comm * comm, int num_vcis, int *all_num_vcis)
1014
{
1115
int mpi_errno = MPI_SUCCESS;
1216

17+
MPIR_Assert(MPIDI_UCX_global.num_vcis == 1);
1318
MPIDI_UCX_global.num_vcis = num_vcis;
1419

15-
/* set up local vcis */
20+
mpi_errno = MPIR_Allgather_impl(&MPIDI_UCX_global.num_vcis, 1, MPIR_INT_INTERNAL,
21+
all_num_vcis, 1, MPIR_INT_INTERNAL, comm, MPIR_ERR_NONE);
22+
MPIR_ERR_CHECK(mpi_errno);
23+
1624
for (int i = 1; i < MPIDI_UCX_global.num_vcis; i++) {
1725
mpi_errno = MPIDI_UCX_init_worker(i);
1826
MPIR_ERR_CHECK(mpi_errno);
1927
}
2028

21-
/* UCX netmod only support the same number of vcis on all procs */
22-
for (int i = 0; i < comm->local_size; i++) {
23-
all_num_vcis[i] = num_vcis;
24-
}
25-
26-
/* address exchange */
27-
if (num_vcis > 1) {
28-
mpi_errno = MPIDI_UCX_all_vcis_address_exchange();
29-
MPIR_ERR_CHECK(mpi_errno);
30-
}
29+
mpi_errno = all_vcis_address_exchange(comm);
30+
MPIR_ERR_CHECK(mpi_errno);
3131

3232
if (MPIR_CVAR_DEBUG_SUMMARY && comm->rank == 0) {
3333
printf("num_vcis: %d\n", MPIDI_UCX_global.num_vcis);
3434
}
35+
/* Flush all pending wireup operations, or they may interfere with the RMA flush_ops count.
36+
* Since this requires progress in non-zero vcis, we need to switch on is_initialized. */
37+
flush_all();
3538

3639
fn_exit:
3740
return mpi_errno;
3841
fn_fail:
3942
goto fn_exit;
4043
}
44+
45+
static int all_vcis_address_exchange(MPIR_Comm * comm)
46+
{
47+
int mpi_errno = MPI_SUCCESS;
48+
49+
int size = MPIR_Process.size;
50+
int rank = MPIR_Process.rank;
51+
int num_vcis = MPIDI_UCX_global.num_vcis;
52+
53+
/* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
54+
size_t name_len = MPID_MAX_BC_SIZE;
55+
56+
int my_len = num_vcis * name_len;
57+
char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
58+
MPIR_Assert(all_names);
59+
60+
char *my_names = all_names + rank * my_len;
61+
62+
/* put in my addrnames */
63+
for (int i = 0; i < num_vcis; i++) {
64+
char *vci_addrname = my_names + i * name_len;
65+
memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address,
66+
MPIDI_UCX_global.ctx[i].addrname_len);
67+
}
68+
/* Allgather */
69+
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPIR_BYTE_INTERNAL,
70+
all_names, my_len, MPIR_BYTE_INTERNAL,
71+
comm, MPIR_ERR_NONE);
72+
MPIR_ERR_CHECK(mpi_errno);
73+
74+
/* insert the addresses */
75+
ucp_ep_params_t ep_params;
76+
for (int vci_local = 0; vci_local < num_vcis; vci_local++) {
77+
for (int r = 0; r < size; r++) {
78+
MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
79+
for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) {
80+
if (vci_local == 0 && vci_remote == 0) {
81+
/* don't overwrite existing addr, or bad things will happen */
82+
continue;
83+
}
84+
int idx = r * num_vcis + vci_remote;
85+
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
86+
ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
87+
88+
ucs_status_t ucx_status;
89+
ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker,
90+
&ep_params, &av->dest[vci_local][vci_remote]);
91+
MPIDI_UCX_CHK_STATUS(ucx_status);
92+
}
93+
}
94+
}
95+
fn_exit:
96+
MPL_free(all_names);
97+
return mpi_errno;
98+
fn_fail:
99+
goto fn_exit;
100+
}
101+
102+
static void flush_cb(void *request, ucs_status_t status)
103+
{
104+
}
105+
106+
static void flush_all(void)
107+
{
108+
void *reqs[MPIDI_CH4_MAX_VCIS];
109+
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
110+
reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb);
111+
}
112+
for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) {
113+
if (reqs[vci] == NULL) {
114+
continue;
115+
} else if (UCS_PTR_IS_ERR(reqs[vci])) {
116+
continue;
117+
} else {
118+
ucs_status_t status;
119+
do {
120+
MPID_Progress_test(NULL);
121+
status = ucp_request_check_status(reqs[vci]);
122+
} while (status == UCS_INPROGRESS);
123+
ucp_request_release(reqs[vci]);
124+
}
125+
}
126+
}

src/mpid/ch4/shm/posix/eager/include/posix_eager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
typedef int (*MPIDI_POSIX_eager_init_t) (int rank, int size);
1414
typedef int (*MPIDI_POSIX_eager_post_init_t) (void);
15+
typedef int (*MPIDI_POSIX_eager_set_vcis_t) (MPIR_Comm * comm, int num_vcis);
1516
typedef int (*MPIDI_POSIX_eager_finalize_t) (void);
1617

1718
typedef int (*MPIDI_POSIX_eager_send_t) (int grank, MPIDI_POSIX_am_header_t * msg_hdr,
@@ -37,6 +38,7 @@ typedef size_t(*MPIDI_POSIX_eager_buf_limit_t) (void);
3738
typedef struct {
3839
MPIDI_POSIX_eager_init_t init;
3940
MPIDI_POSIX_eager_post_init_t post_init;
41+
MPIDI_POSIX_eager_set_vcis_t set_vcis;
4042
MPIDI_POSIX_eager_finalize_t finalize;
4143

4244
MPIDI_POSIX_eager_send_t send;
@@ -59,6 +61,7 @@ extern char MPIDI_POSIX_eager_strings[][MPIDI_MAX_POSIX_EAGER_STRING_LEN];
5961

6062
int MPIDI_POSIX_eager_init(int rank, int size);
6163
int MPIDI_POSIX_eager_post_init(void);
64+
int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm, int num_vcis);
6265
int MPIDI_POSIX_eager_finalize(void);
6366

6467
MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_eager_send(int grank, MPIDI_POSIX_am_header_t * msg_hdr,

0 commit comments

Comments
 (0)