|
5 | 5 | #include "opal_config.h" |
6 | 6 |
|
7 | 7 | #include "common_ucx_int.h" |
| 8 | +#include "common_ucx_request.h" |
8 | 9 | #include <stdint.h> |
9 | 10 |
|
10 | 11 | #include <ucp/api/ucp.h> |
@@ -176,9 +177,7 @@ static inline void opal_common_ucx_wpool_dbg_init(void) |
176 | 177 | OPAL_DECLSPEC opal_common_ucx_wpool_t * opal_common_ucx_wpool_allocate(void); |
177 | 178 | OPAL_DECLSPEC void opal_common_ucx_wpool_free(opal_common_ucx_wpool_t *wpool); |
178 | 179 | OPAL_DECLSPEC int opal_common_ucx_wpool_init(opal_common_ucx_wpool_t *wpool, |
179 | | - int proc_world_size, |
180 | | - ucp_request_init_callback_t req_init_ptr, |
181 | | - size_t req_size, bool enable_mt); |
| 180 | + int proc_world_size, bool enable_mt); |
182 | 181 | OPAL_DECLSPEC void opal_common_ucx_wpool_finalize(opal_common_ucx_wpool_t *wpool); |
183 | 182 | OPAL_DECLSPEC void opal_common_ucx_wpool_progress(opal_common_ucx_wpool_t *wpool); |
184 | 183 |
|
@@ -394,27 +393,40 @@ opal_common_ucx_wpmem_fetch(opal_common_ucx_wpmem_t *mem, |
394 | 393 |
|
395 | 394 | static inline int |
396 | 395 | opal_common_ucx_wpmem_fetch_nb(opal_common_ucx_wpmem_t *mem, |
397 | | - ucp_atomic_fetch_op_t opcode, |
398 | | - uint64_t value, |
399 | | - int target, void *buffer, size_t len, |
400 | | - uint64_t rem_addr, ucs_status_ptr_t *ptr) |
| 396 | + ucp_atomic_fetch_op_t opcode, |
| 397 | + uint64_t value, |
| 398 | + int target, void *buffer, size_t len, |
| 399 | + uint64_t rem_addr, |
| 400 | + opal_common_ucx_user_req_handler_t user_req_cb, |
| 401 | + void *user_req_ptr) |
401 | 402 | { |
402 | 403 | ucp_ep_h ep = NULL; |
403 | 404 | ucp_rkey_h rkey = NULL; |
404 | 405 | opal_common_ucx_winfo_t *winfo = NULL; |
405 | 406 | int rc = OPAL_SUCCESS; |
| 407 | + opal_common_ucx_request_t *req; |
| 408 | + |
406 | 409 | rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); |
407 | 410 | if(OPAL_UNLIKELY(OPAL_SUCCESS != rc)){ |
408 | 411 | MCA_COMMON_UCX_ERROR("tlocal_fetch failed: %d", rc); |
409 | 412 | return rc; |
410 | 413 | } |
411 | 414 | /* Perform the operation */ |
412 | 415 | opal_mutex_lock(&winfo->mutex); |
413 | | - (*ptr) = opal_common_ucx_atomic_fetch_nb(ep, opcode, value, |
414 | | - buffer, len, |
415 | | - rem_addr, rkey, |
416 | | - winfo->worker); |
| 416 | + req = opal_common_ucx_atomic_fetch_nb(ep, opcode, value, buffer, len, |
| 417 | + rem_addr, rkey, opal_common_ucx_req_completion, |
| 418 | + winfo->worker); |
417 | 419 | opal_mutex_unlock(&winfo->mutex); |
| 420 | + |
| 421 | + if (UCS_PTR_IS_PTR(req)) { |
| 422 | + req->ext_req = user_req_ptr; |
| 423 | + req->ext_cb = user_req_cb; |
| 424 | + } else { |
| 425 | + if (user_req_cb != NULL) { |
| 426 | + (*user_req_cb)(user_req_ptr); |
| 427 | + } |
| 428 | + } |
| 429 | + |
418 | 430 | return rc; |
419 | 431 | } |
420 | 432 |
|
|
0 commit comments