55#include "opal_config.h"
66
77#include "common_ucx_int.h"
8- #include "common_ucx_request.h"
98#include <stdint.h>
109#include <string.h>
1110
@@ -79,7 +78,7 @@ typedef struct {
7978 pthread_key_t mem_tls_key ;
8079} opal_common_ucx_wpmem_t ;
8180
82- typedef struct {
81+ typedef struct opal_common_ucx_winfo {
8382 opal_recursive_mutex_t mutex ;
8483 volatile int released ;
8584 ucp_worker_h worker ;
@@ -95,6 +94,14 @@ typedef struct {
9594 ucp_rkey_h * rkeys ;
9695} opal_common_ucx_tlocal_fast_ptrs_t ;
9796
97+ typedef void (* opal_common_ucx_user_req_handler_t )(void * request );
98+
99+ typedef struct {
100+ void * ext_req ;
101+ opal_common_ucx_user_req_handler_t ext_cb ;
102+ opal_common_ucx_winfo_t * winfo ;
103+ } opal_common_ucx_request_t ;
104+
98105typedef enum {
99106 OPAL_COMMON_UCX_PUT ,
100107 OPAL_COMMON_UCX_GET
@@ -198,6 +205,10 @@ OPAL_DECLSPEC int opal_common_ucx_wpctx_create(opal_common_ucx_wpool_t *wpool, i
198205 opal_common_ucx_ctx_t * * ctx_ptr );
199206OPAL_DECLSPEC void opal_common_ucx_wpctx_release (opal_common_ucx_ctx_t * ctx );
200207
208+ /* request init / completion */
209+ OPAL_DECLSPEC void opal_common_ucx_req_init (void * request );
210+ OPAL_DECLSPEC void opal_common_ucx_req_completion (void * request , ucs_status_t status );
211+
201212/* Managing thread local storage */
202213OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath (opal_common_ucx_wpmem_t * mem , int target );
203214static inline int
@@ -246,10 +257,57 @@ OPAL_DECLSPEC int opal_common_ucx_wpmem_flush(opal_common_ucx_wpmem_t *mem,
246257 int target );
247258OPAL_DECLSPEC int opal_common_ucx_wpmem_fence (opal_common_ucx_wpmem_t * mem );
248259
249- OPAL_DECLSPEC int opal_common_ucx_flush (ucp_ep_h ep , ucp_worker_h worker ,
250- opal_common_ucx_flush_type_t type ,
251- opal_common_ucx_flush_scope_t scope ,
252- ucs_status_ptr_t * req_ptr );
260+ OPAL_DECLSPEC int opal_common_ucx_winfo_flush (opal_common_ucx_winfo_t * winfo , int target ,
261+ opal_common_ucx_flush_type_t type ,
262+ opal_common_ucx_flush_scope_t scope ,
263+ ucs_status_ptr_t * req_ptr );
264+
265+ static inline
266+ int opal_common_ucx_wait_request_mt (ucs_status_ptr_t request , const char * msg )
267+ {
268+ ucs_status_t status ;
269+ int ctr = 0 , ret = 0 ;
270+ opal_common_ucx_winfo_t * winfo ;
271+
272+ /* check for request completed or failed */
273+ if (OPAL_LIKELY (UCS_OK == request )) {
274+ return OPAL_SUCCESS ;
275+ } else if (OPAL_UNLIKELY (UCS_PTR_IS_ERR (request ))) {
276+ MCA_COMMON_UCX_VERBOSE (1 , "%s failed: %d, %s" , msg ? msg : __func__ ,
277+ UCS_PTR_STATUS (request ),
278+ ucs_status_string (UCS_PTR_STATUS (request )));
279+ return OPAL_ERROR ;
280+ }
281+
282+ winfo = ((opal_common_ucx_request_t * )request )-> winfo ;
283+ assert (winfo != NULL );
284+
285+ do {
286+ ctr = opal_common_ucx .progress_iterations ;
287+ opal_mutex_lock (& winfo -> mutex );
288+ do {
289+ ret = ucp_worker_progress (winfo -> worker );
290+ status = opal_common_ucx_request_status (request );
291+ if (status != UCS_INPROGRESS ) {
292+ ucp_request_free (request );
293+ if (OPAL_UNLIKELY (UCS_OK != status )) {
294+ MCA_COMMON_UCX_VERBOSE (1 , "%s failed: %d, %s" ,
295+ msg ? msg : __func__ ,
296+ UCS_PTR_STATUS (request ),
297+ ucs_status_string (UCS_PTR_STATUS (request )));
298+ opal_mutex_unlock (& winfo -> mutex );
299+ return OPAL_ERROR ;
300+ }
301+ break ;
302+ }
303+ ctr -- ;
304+ } while (ctr > 0 && ret > 0 && status == UCS_INPROGRESS );
305+ opal_mutex_unlock (& winfo -> mutex );
306+ opal_progress ();
307+ } while (status == UCS_INPROGRESS );
308+
309+ return OPAL_SUCCESS ;
310+ }
253311
254312static inline int _periodical_flush_nb (opal_common_ucx_wpmem_t * mem ,
255313 opal_common_ucx_winfo_t * winfo ,
@@ -264,8 +322,8 @@ static inline int _periodical_flush_nb(opal_common_ucx_wpmem_t *mem,
264322 opal_common_ucx_flush_scope_t scope ;
265323
266324 if (winfo -> inflight_req != UCS_OK ) {
267- rc = opal_common_ucx_wait_request (winfo -> inflight_req , winfo -> worker ,
268- "opal_common_ucx_flush_nb" );
325+ rc = opal_common_ucx_wait_request_mt (winfo -> inflight_req ,
326+ "opal_common_ucx_flush_nb" );
269327 if (OPAL_UNLIKELY (OPAL_SUCCESS != rc )){
270328 MCA_COMMON_UCX_VERBOSE (1 , "opal_common_ucx_wait_request failed: %d" , rc );
271329 return rc ;
@@ -283,13 +341,13 @@ static inline int _periodical_flush_nb(opal_common_ucx_wpmem_t *mem,
283341 winfo -> inflight_ops [target ] = 0 ;
284342 }
285343
286- rc = opal_common_ucx_flush (winfo -> endpoints [target ], winfo -> worker ,
287- OPAL_COMMON_UCX_FLUSH_NB_PREFERRED , scope ,
288- & winfo -> inflight_req );
344+ rc = opal_common_ucx_winfo_flush (winfo , target , OPAL_COMMON_UCX_FLUSH_NB_PREFERRED ,
345+ scope , & winfo -> inflight_req );
289346 if (OPAL_UNLIKELY (OPAL_SUCCESS != rc )){
290347 MCA_COMMON_UCX_VERBOSE (1 , "opal_common_ucx_flush failed: %d" , rc );
291348 return rc ;
292349 }
350+ ((opal_common_ucx_request_t * )winfo -> inflight_req )-> winfo = winfo ;
293351 } else if (OPAL_UNLIKELY (winfo -> inflight_req != UCS_OK )) {
294352 int ret ;
295353 do {
@@ -510,6 +568,7 @@ opal_common_ucx_wpmem_fetch_nb(opal_common_ucx_wpmem_t *mem,
510568 if (UCS_PTR_IS_PTR (req )) {
511569 req -> ext_req = user_req_ptr ;
512570 req -> ext_cb = user_req_cb ;
571+ req -> winfo = winfo ;
513572 } else {
514573 if (user_req_cb != NULL ) {
515574 (* user_req_cb )(user_req_ptr );
0 commit comments