diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c
index dd16a27b154..42f1d148c3a 100644
--- a/ompi/mca/pml/ucx/pml_ucx.c
+++ b/ompi/mca/pml/ucx/pml_ucx.c
@@ -9,6 +9,7 @@
  * Copyright (c) 2019 Intel, Inc. All rights reserved.
  * Copyright (c) 2022 Amazon.com, Inc. or its affiliates.
  * All Rights reserved.
+ * Copyright (c) 2024 NVIDIA Corporation.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -95,6 +96,11 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
 #define PML_UCX_REQ_ALLOCA() \
     ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size);
 
+/* The PML endpoint slot holds NULL (no endpoint yet), (void*)1 (placeholder
+ * set while one thread creates the endpoint) or the endpoint itself. */
+#define PML_UCX_IS_VALID_ENDPOINT(PROC) \
+    ((PROC)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] > (void*)1)
+
 #if HAVE_UCP_WORKER_ADDRESS_FLAGS
 static int mca_pml_ucx_send_worker_address_type(int addr_flags, int modex_scope)
 {
@@ -407,17 +413,29 @@ static ucp_ep_h mca_pml_ucx_add_proc_common(ompi_proc_t *proc)
     ucp_ep_params_t ep_params;
     ucp_address_t *address;
     ucs_status_t status;
-    ucp_ep_h ep;
+    ucp_ep_h ep = NULL;
     int ret;
 
+ check_again:
     /* Do not add a new endpoint if we already created one */
-    if (NULL != proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+    if (PML_UCX_IS_VALID_ENDPOINT(proc)) {
         return proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
     }
+    if (!OPAL_ATOMIC_COMPARE_EXCHANGE_STRONG_PTR((opal_atomic_intptr_t *)&proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML],
+                                                 &ep, (void*)1)) {
+        /* give some slack to the other thread to create the endpoint */
+        while ((void*)1 == proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            _opal_lifo_release_cpu();
+            opal_atomic_rmb();
+        }
+        /* the failed exchange overwrote ep; expect NULL again on the retry */
+        ep = NULL;
+        goto check_again;
+    }
 
     ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
     if (ret < 0) {
-        return NULL;
+        goto return_with_failure;
     }
 
     PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
@@ -431,11 +449,16 @@ static ucp_ep_h mca_pml_ucx_add_proc_common(ompi_proc_t *proc)
         PML_UCX_ERROR("ucp_ep_create(proc=%d) failed: %s",
                       proc->super.proc_name.vpid,
                       ucs_status_string(status));
-        return NULL;
+        goto return_with_failure;
     }
-
+    opal_atomic_wmb(); /* make sure the ep is visible */
     proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     return ep;
+
+ return_with_failure:
+    /* we are responsible for setting the proc_endpoint to NULL */
+    proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+    return NULL;
 }
 
 int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)