Skip to content

Commit 4dea369

Browse files
committed
ch4: remove the usage of MPIDI_global.is_initialized
Now that we have separated the multi-VCI initialization from the base init, we no longer need to use MPIDI_global.is_initialized to guard against premature access to extra VCIs during init. We still need a flag to check whether the root endpoints have been set up before we try to tear them down in OFI/UCX finalize. Use a static initialized flag for this purpose.
1 parent 3a4f0cf commit 4dea369

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

src/mpid/ch4/netmod/ofi/ofi_init.c

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,8 @@ categories :
549549
=== END_MPI_T_CVAR_INFO_BLOCK ===
550550
*/
551551

552+
static bool ofi_initialized = false;
553+
552554
static int update_global_limits(struct fi_info *prov);
553555
static void dump_global_settings(void);
554556
static int destroy_vci_context(int vci, int nic);
@@ -799,6 +801,8 @@ int MPIDI_OFI_init_world(void)
799801
MPIR_ERR_CHECK(mpi_errno);
800802
}
801803

804+
ofi_initialized = true;
805+
802806
fn_exit:
803807
return mpi_errno;
804808
fn_fail:
@@ -879,10 +883,14 @@ static int flush_send_queue(void)
879883
{
880884
int mpi_errno = MPI_SUCCESS;
881885

886+
if (!ofi_initialized) {
887+
goto fn_exit;
888+
}
889+
882890
MPIDI_OFI_dynamic_process_request_t *reqs;
883891
/* TODO - Iterate over each NIC in addition to each VNI when multi-NIC within the same
884892
* process is implemented. */
885-
int num_vcis = (MPIDI_global.is_initialized ? MPIDI_OFI_global.num_vcis : 1);
893+
int num_vcis = MPIDI_OFI_global.num_vcis;
886894
int num_reqs = num_vcis * 2;
887895
reqs = MPL_malloc(sizeof(MPIDI_OFI_dynamic_process_request_t) * num_reqs, MPL_MEM_OTHER);
888896

@@ -937,12 +945,9 @@ int MPIDI_OFI_mpi_finalize_hook(void)
937945
MPIDI_OFI_mr_key_allocator_destroy();
938946

939947
if (strcmp("sockets", MPIDI_OFI_global.prov_use[0]->fabric_attr->prov_name) == 0) {
940-
/* sockets provider need flush any last lightweight send. Only do it if we initialized
941-
* world. Sockets provider can't even send self messages otherwise. */
942-
if (MPIDI_global.is_initialized) {
943-
mpi_errno = flush_send_queue();
944-
MPIR_ERR_CHECK(mpi_errno);
945-
}
948+
/* sockets provider need flush any last lightweight send. */
949+
mpi_errno = flush_send_queue();
950+
MPIR_ERR_CHECK(mpi_errno);
946951
} else if (MPIR_CVAR_NO_COLLECTIVE_FINALIZE) {
947952
/* skip collective work arounds */
948953
} else if (strcmp("verbs;ofi_rxm", MPIDI_OFI_global.prov_use[0]->fabric_attr->prov_name) == 0
@@ -980,12 +985,10 @@ int MPIDI_OFI_mpi_finalize_hook(void)
980985
/* Tearing down endpoints in reverse order they were created */
981986
for (int nic = MPIDI_OFI_global.num_nics - 1; nic >= 0; nic--) {
982987
for (int vci = MPIDI_OFI_global.num_vcis - 1; vci >= 0; vci--) {
983-
if (MPIDI_global.is_initialized || (vci == 0 && nic == 0)) {
984-
/* If the user has not freed all MPI objects, ofi might not shut down cleanly.
985-
* We intentionally ignore errors to avoid crashing in finalize. Debug builds
986-
* will warn about unfreed objects/memory. */
987-
(void) destroy_vci_context(vci, nic);
988-
}
988+
/* If the user has not freed all MPI objects, ofi might not shut down cleanly.
989+
* We intentionally ignore errors to avoid crashing in finalize. Debug builds
990+
* will warn about unfreed objects/memory. */
991+
(void) destroy_vci_context(vci, nic);
989992
}
990993
}
991994

src/mpid/ch4/netmod/ucx/ucx_init.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ categories :
1818
=== END_MPI_T_CVAR_INFO_BLOCK ===
1919
*/
2020

21+
static bool ucx_initialized = false;
22+
2123
static void request_init_callback(void *request)
2224
{
2325

@@ -211,6 +213,8 @@ int MPIDI_UCX_init_world(void)
211213
mpi_errno = initial_address_exchange();
212214
MPIR_ERR_CHECK(mpi_errno);
213215

216+
ucx_initialized = true;
217+
214218
fn_exit:
215219
return mpi_errno;
216220
fn_fail:
@@ -231,13 +235,12 @@ int MPIDI_UCX_mpi_finalize_hook(void)
231235
{
232236
int mpi_errno = MPI_SUCCESS;
233237

234-
if (!MPIDI_global.is_initialized) {
235-
/* Nothing to do */
236-
return mpi_errno;
237-
}
238-
239238
ucs_status_ptr_t ucp_request;
240-
ucs_status_ptr_t *pending;
239+
ucs_status_ptr_t *pending = NULL;
240+
241+
if (!ucx_initialized) {
242+
goto fn_exit;
243+
}
241244

242245
int n = MPIDI_UCX_global.num_vcis;
243246
pending = MPL_malloc(sizeof(ucs_status_ptr_t) * MPIR_Process.size * n * n, MPL_MEM_OTHER);

src/mpid/ch4/src/ch4_progress.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ extern MPL_TLS int global_vci_poll_count;
3535

3636
MPL_STATIC_INLINE_PREFIX int MPIDI_do_global_progress(void)
3737
{
38-
if (MPIDI_global.n_vcis == 1 || !MPIDI_global.is_initialized || !MPIR_CVAR_CH4_GLOBAL_PROGRESS) {
38+
if (MPIDI_global.n_vcis == 1 || !MPIR_CVAR_CH4_GLOBAL_PROGRESS) {
3939
return 0;
4040
} else {
4141
global_vci_poll_count++;
@@ -161,16 +161,11 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_progress_state_init(MPID_Progress_state * st
161161
state->flag |= MPIDI_PROGRESS_NM_LOCKLESS;
162162
}
163163

164-
if (!MPIDI_global.is_initialized) {
165-
state->vci[0] = 0;
166-
state->vci_count = 1;
167-
} else {
168-
/* global progress by default */
169-
for (int i = 0; i < MPIDI_global.n_vcis; i++) {
170-
state->vci[i] = i;
171-
}
172-
state->vci_count = MPIDI_global.n_vcis;
164+
/* global progress by default */
165+
for (int i = 0; i < MPIDI_global.n_vcis; i++) {
166+
state->vci[i] = i;
173167
}
168+
state->vci_count = MPIDI_global.n_vcis;
174169
}
175170

176171
MPL_STATIC_INLINE_PREFIX int MPIDI_Progress_test(int flags)

0 commit comments

Comments
 (0)