diff --git a/opal/mca/btl/uct/btl_uct_endpoint.c b/opal/mca/btl/uct/btl_uct_endpoint.c index aa255a0f7f1..695fd754aa2 100644 --- a/opal/mca/btl/uct/btl_uct_endpoint.c +++ b/opal/mca/btl/uct/btl_uct_endpoint.c @@ -38,6 +38,9 @@ static void mca_btl_uct_endpoint_destruct(mca_btl_uct_endpoint_t *endpoint) } OBJ_DESTRUCT(&endpoint->ep_lock); + if (endpoint->conn_ep) { + OBJ_RELEASE(endpoint->conn_ep); + } } OBJ_CLASS_INSTANCE(mca_btl_uct_endpoint_t, opal_object_t, mca_btl_uct_endpoint_construct, @@ -206,7 +209,6 @@ static int mca_btl_uct_endpoint_send_conn_req(mca_btl_uct_module_t *uct_btl, mca_btl_uct_conn_req_t *request, size_t request_length) { - mca_btl_uct_connection_ep_t *conn_ep = endpoint->conn_ep; mca_btl_uct_conn_completion_t completion = {.super = {.count = 1, .func = mca_btl_uct_endpoint_flush_complete}, .complete = false}; ucs_status_t ucs_status; @@ -215,15 +217,13 @@ static int mca_btl_uct_endpoint_send_conn_req(mca_btl_uct_module_t *uct_btl, ("sending connection request to peer. context id: %d, type: %d, length: %" PRIsize_t, request->context_id, request->type, request_length)); - OBJ_RETAIN(endpoint->conn_ep); - /* need to drop the lock to avoid hold-and-wait */ opal_mutex_unlock(&endpoint->ep_lock); do { MCA_BTL_UCT_CONTEXT_SERIALIZE(conn_tl_context, { - ucs_status = uct_ep_am_short(conn_ep->uct_ep, MCA_BTL_UCT_CONNECT_RDMA, request->type, - request, request_length); + ucs_status = uct_ep_am_short(endpoint->conn_ep->uct_ep, MCA_BTL_UCT_CONNECT_RDMA, + request->type, request, request_length); }); if (OPAL_LIKELY(UCS_OK == ucs_status)) { break; @@ -238,11 +238,11 @@ static int mca_btl_uct_endpoint_send_conn_req(mca_btl_uct_module_t *uct_btl, } while (1); /* for now we just wait for the connection request to complete before continuing */ - ucs_status = uct_ep_flush(conn_ep->uct_ep, 0, &completion.super); + ucs_status = uct_ep_flush(endpoint->conn_ep->uct_ep, 0, &completion.super); if (UCS_OK != ucs_status && UCS_INPROGRESS != ucs_status) { /* NTH: I don't know if this path is needed. For some networks we must use a completion. */ do { - ucs_status = uct_ep_flush(conn_ep->uct_ep, 0, NULL); + ucs_status = uct_ep_flush(endpoint->conn_ep->uct_ep, 0, NULL); mca_btl_uct_context_progress(conn_tl_context); } while (UCS_INPROGRESS == ucs_status); } else { @@ -253,8 +253,6 @@ static int mca_btl_uct_endpoint_send_conn_req(mca_btl_uct_module_t *uct_btl, opal_mutex_lock(&endpoint->ep_lock); - OBJ_RELEASE(endpoint->conn_ep); - return OPAL_SUCCESS; } @@ -265,7 +263,6 @@ static int mca_btl_uct_endpoint_send_connection_data( { mca_btl_uct_tl_t *conn_tl = uct_btl->conn_tl; mca_btl_uct_device_context_t *conn_tl_context = conn_tl->uct_dev_contexts[0]; - mca_btl_uct_connection_ep_t *conn_ep = endpoint->conn_ep; uct_device_addr_t *device_addr = NULL; uct_iface_addr_t *iface_addr; ucs_status_t ucs_status; @@ -274,7 +271,7 @@ static int mca_btl_uct_endpoint_send_connection_data( BTL_VERBOSE(("connecting endpoint to remote endpoint")); - if (NULL == conn_ep) { + if (NULL == endpoint->conn_ep) { BTL_VERBOSE(("creating a temporary endpoint for handling connections to %p", opal_process_name_print(endpoint->ep_proc->proc_name))); @@ -282,8 +279,8 @@ static int mca_btl_uct_endpoint_send_connection_data( device_addr = (uct_device_addr_t *) ((uintptr_t) conn_tl_data + MCA_BTL_UCT_TL_ATTR(conn_tl, 0).iface_addr_len); - endpoint->conn_ep = conn_ep = OBJ_NEW(mca_btl_uct_connection_ep_t); - if (OPAL_UNLIKELY(NULL == conn_ep)) { + endpoint->conn_ep = OBJ_NEW(mca_btl_uct_connection_ep_t); + if (OPAL_UNLIKELY(NULL == endpoint->conn_ep)) { return OPAL_ERR_OUT_OF_RESOURCE; } @@ -291,7 +288,7 @@ static int mca_btl_uct_endpoint_send_connection_data( MCA_BTL_UCT_CONTEXT_SERIALIZE(conn_tl_context, { ucs_status = mca_btl_uct_ep_create_connected_compat(conn_tl_context->uct_iface, device_addr, iface_addr, - &conn_ep->uct_ep); + &endpoint->conn_ep->uct_ep); }); if (UCS_OK != ucs_status) { BTL_VERBOSE( @@ -299,6 +296,8 @@ static int mca_btl_uct_endpoint_send_connection_data( ucs_status)); return OPAL_ERROR; } + } else { + OBJ_RETAIN(endpoint->conn_ep); } size_t request_length = sizeof(mca_btl_uct_conn_req_t) @@ -368,6 +367,9 @@ static int mca_btl_uct_endpoint_connect_endpoint( if (UCS_OK != ucs_status) { return OPAL_ERROR; } + + mca_btl_uct_endpoint_set_flag(uct_btl, endpoint, tl_context->context_id, tl_endpoint, + MCA_BTL_UCT_ENDPOINT_FLAG_EP_CONNECTED); } opal_timer_t now = opal_timer_base_get_usec(); @@ -391,7 +393,6 @@ int mca_btl_uct_endpoint_connect(mca_btl_uct_module_t *uct_btl, mca_btl_uct_endp mca_btl_uct_device_context_t *tl_context = mca_btl_uct_module_get_tl_context_specific(uct_btl, tl, context_id); uint8_t *rdma_tl_data = NULL, *conn_tl_data = NULL, *am_tl_data = NULL, *tl_data; - mca_btl_uct_connection_ep_t *conn_ep = NULL; mca_btl_uct_modex_t *modex; uint8_t *modex_data; size_t msg_size; @@ -474,19 +475,8 @@ int mca_btl_uct_endpoint_connect(mca_btl_uct_module_t *uct_btl, mca_btl_uct_endp } while (0); - /* to avoid a possible hold-and wait deadlock. destroy the endpoint after dropping the endpoint - * lock. */ - if (endpoint->conn_ep && 1 == endpoint->conn_ep->super.obj_reference_count) { - conn_ep = endpoint->conn_ep; - endpoint->conn_ep = NULL; - } - opal_mutex_unlock(&endpoint->ep_lock); - if (conn_ep) { - OBJ_RELEASE(conn_ep); - } - BTL_VERBOSE(("endpoint%s ready for use", (OPAL_ERR_OUT_OF_RESOURCE != rc) ? "" : " not yet")); return rc; diff --git a/opal/mca/btl/uct/btl_uct_endpoint.h b/opal/mca/btl/uct/btl_uct_endpoint.h index 49d1b941457..0c4bae20050 100644 --- a/opal/mca/btl/uct/btl_uct_endpoint.h +++ b/opal/mca/btl/uct/btl_uct_endpoint.h @@ -14,6 +14,7 @@ * reserved. * Copyright (c) 2020 Amazon.com, Inc. or its affiliates. * All Rights reserved. + * Copyright (c) 2025 Google, LLC. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -106,5 +107,29 @@ static inline int mca_btl_uct_endpoint_check_am(mca_btl_uct_module_t *module, module->am_tl->tl_index); } +// Requires that the endpoint lock is held. +static inline void mca_btl_uct_endpoint_set_flag(mca_btl_uct_module_t *module, mca_btl_uct_endpoint_t *endpoint, + int context_id, mca_btl_uct_tl_endpoint_t *tl_endpoint, int32_t flag) { + opal_atomic_wmb(); + int32_t flag_value = opal_atomic_or_fetch_32(&tl_endpoint->flags, flag); + if ((flag_value & (MCA_BTL_UCT_ENDPOINT_FLAG_EP_CONNECTED | MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REM_READY)) == + (MCA_BTL_UCT_ENDPOINT_FLAG_EP_CONNECTED | MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REM_READY)) { + opal_atomic_fetch_or_32(&tl_endpoint->flags, MCA_BTL_UCT_ENDPOINT_FLAG_CONN_READY); + + opal_atomic_wmb(); + + mca_btl_uct_base_frag_t *frag; + OPAL_LIST_FOREACH (frag, &module->pending_frags, mca_btl_uct_base_frag_t) { + if (frag->context->context_id == context_id && endpoint == frag->endpoint) { + frag->ready = true; + } + } + + if (endpoint->conn_ep) { + OBJ_RELEASE(endpoint->conn_ep); + } + } +} + END_C_DECLS #endif diff --git a/opal/mca/btl/uct/btl_uct_tl.c b/opal/mca/btl/uct/btl_uct_tl.c index 5669e88c061..c1ef4c6d727 100644 --- a/opal/mca/btl/uct/btl_uct_tl.c +++ b/opal/mca/btl/uct/btl_uct_tl.c @@ -6,7 +6,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2018 Triad National Security, LLC. All rights * reserved. - * Copyright (c) 2019 Google, LLC. All rights reserved. + * Copyright (c) 2019-2025 Google, LLC. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -199,9 +199,6 @@ int mca_btl_uct_process_connection_request(mca_btl_uct_module_t *module, int32_t ep_flags; int rc; - BTL_VERBOSE(("got connection request for endpoint %p. type = %d. context id = %d", - (void *) endpoint, req->type, req->context_id)); - if (NULL == endpoint) { BTL_ERROR(("could not create endpoint for connection request")); return UCS_ERR_UNREACHABLE; @@ -211,6 +208,9 @@ int mca_btl_uct_process_connection_request(mca_btl_uct_module_t *module, ep_flags = opal_atomic_fetch_or_32(&tl_endpoint->flags, MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REC); + BTL_VERBOSE(("got connection request for endpoint %p. type = %d. context id = %d. ep_flags = %x", + (void *) endpoint, req->type, req->context_id, ep_flags)); + if (!(ep_flags & MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REC)) { /* create any necessary resources */ rc = mca_btl_uct_endpoint_connect(module, endpoint, req->context_id, req->ep_addr, @@ -225,22 +225,13 @@ int mca_btl_uct_process_connection_request(mca_btl_uct_module_t *module, * message. this might be overkill but there is little documentation at the UCT level on when * an endpoint can be used. */ if (req->type == 1) { - /* remote side is ready */ - mca_btl_uct_base_frag_t *frag; - + /* remote side is connected */ /* to avoid a race with send adding pending frags grab the lock here */ OPAL_THREAD_SCOPED_LOCK(&endpoint->ep_lock, { BTL_VERBOSE(("connection ready. sending %" PRIsize_t " frags", opal_list_get_size(&module->pending_frags))); - (void) opal_atomic_or_fetch_32(&tl_endpoint->flags, - MCA_BTL_UCT_ENDPOINT_FLAG_CONN_READY); - opal_atomic_wmb(); - - OPAL_LIST_FOREACH (frag, &module->pending_frags, mca_btl_uct_base_frag_t) { - if (frag->context->context_id == req->context_id && endpoint == frag->endpoint) { - frag->ready = true; - } - } + mca_btl_uct_endpoint_set_flag(module, endpoint, req->context_id, tl_endpoint, + MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REM_READY); }); } diff --git a/opal/mca/btl/uct/btl_uct_types.h b/opal/mca/btl/uct/btl_uct_types.h index cd331986b8a..b2bac61be61 100644 --- a/opal/mca/btl/uct/btl_uct_types.h +++ b/opal/mca/btl/uct/btl_uct_types.h @@ -27,8 +27,10 @@ struct mca_btl_uct_base_frag_t; # define MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REC 0x1 /** remote endpoint read */ # define MCA_BTL_UCT_ENDPOINT_FLAG_CONN_REM_READY 0x2 +/** local UCT endpoint connected */ +# define MCA_BTL_UCT_ENDPOINT_FLAG_EP_CONNECTED 0x4 /** connection was established */ -# define MCA_BTL_UCT_ENDPOINT_FLAG_CONN_READY 0x4 +# define MCA_BTL_UCT_ENDPOINT_FLAG_CONN_READY 0x8 /* AM tags */ /** BTL fragment */