Skip to content

Commit 557ae80

Browse files
committed
UCX osc: allow for overlap with (some) request-based atomic operations
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent 1a3c6bb commit 557ae80

File tree

1 file changed

+124
-49
lines changed

1 file changed

+124
-49
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 124 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -323,7 +323,7 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323323
return ret;
324324
}
325325

326-
static int atomic_op_replace_sum(
326+
static int do_atomic_op_replace_sum(
327327
ompi_osc_ucx_module_t *module,
328328
struct ompi_op_t *op,
329329
int target,
@@ -333,7 +333,8 @@ static int atomic_op_replace_sum(
333333
ptrdiff_t target_disp,
334334
int target_count,
335335
struct ompi_datatype_t *target_dt,
336-
void *result_addr)
336+
void *result_addr,
337+
ompi_osc_ucx_request_t *ucx_req)
337338
{
338339
int ret = OMPI_SUCCESS;
339340
size_t origin_dt_bytes;
@@ -363,12 +364,27 @@ static int atomic_op_replace_sum(
363364
opcode = UCP_ATOMIC_FETCH_OP_FADD;
364365
}
365366

367+
opal_common_ucx_user_req_handler_t user_req_cb = NULL;
368+
void* user_req_ptr = NULL;
366369
for (int i = 0; i < origin_count; ++i) {
367370
uint64_t value = 0;
371+
if ((origin_count - 1) == i && NULL != ucx_req) {
372+
// the last item is used to feed the request, if needed
373+
user_req_cb = &req_completion;
374+
user_req_ptr = ucx_req;
375+
// issue a fence if this is the last but not the only element
376+
if (0 < i) {
377+
ret = opal_common_ucx_wpmem_fence(module->mem);
378+
if (ret != OMPI_SUCCESS) {
379+
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
380+
return OMPI_ERROR;
381+
}
382+
}
383+
}
368384
memcpy(&value, origin_addr, origin_dt_bytes);
369385
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, opcode, value, target,
370386
result_addr ? result_addr : &(module->req_result),
371-
origin_dt_bytes, remote_addr, NULL, NULL);
387+
origin_dt_bytes, remote_addr, user_req_cb, user_req_ptr);
372388

373389
// advance origin and remote address
374390
origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
@@ -381,7 +397,7 @@ static int atomic_op_replace_sum(
381397
return ret;
382398
}
383399

384-
static int atomic_op_cswap(
400+
static int do_atomic_op_cswap(
385401
ompi_osc_ucx_module_t *module,
386402
struct ompi_op_t *op,
387403
int target,
@@ -391,7 +407,8 @@ static int atomic_op_cswap(
391407
ptrdiff_t target_disp,
392408
int target_count,
393409
struct ompi_datatype_t *target_dt,
394-
void *result_addr)
410+
void *result_addr,
411+
ompi_osc_ucx_request_t *ucx_req)
395412
{
396413
int ret = OMPI_SUCCESS;
397414
size_t origin_dt_bytes;
@@ -432,6 +449,7 @@ static int atomic_op_cswap(
432449
return ret;
433450
}
434451

452+
/* JS: move this loop into the request to overlap multiple cas operations? */
435453
do {
436454

437455
tmp_val = target_val;
@@ -451,6 +469,8 @@ static int atomic_op_cswap(
451469
break;
452470
}
453471

472+
target_val = tmp_val;
473+
454474
} while (1);
455475

456476
// store the result if necessary
@@ -463,6 +483,41 @@ static int atomic_op_cswap(
463483
remote_addr += origin_dt_bytes;
464484
}
465485

486+
if (NULL != ucx_req) {
487+
// nothing to wait for so mark the request as completed
488+
ompi_request_complete(&ucx_req->super, true);
489+
}
490+
491+
return ret;
492+
}
493+
494+
static inline
495+
int do_atomic_op(
496+
ompi_osc_ucx_module_t *module,
497+
struct ompi_op_t *op,
498+
int target,
499+
const void *origin_addr,
500+
int origin_count,
501+
struct ompi_datatype_t *origin_dt,
502+
ptrdiff_t target_disp,
503+
int target_count,
504+
struct ompi_datatype_t *target_dt,
505+
void *result_addr,
506+
ompi_osc_ucx_request_t *ucx_req)
507+
{
508+
int ret;
509+
510+
if (op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) {
511+
ret = do_atomic_op_replace_sum(module, op, target,
512+
origin_addr, origin_count, origin_dt,
513+
target_disp, target_count, target_dt,
514+
result_addr, ucx_req);
515+
} else {
516+
ret = do_atomic_op_cswap(module, op, target,
517+
origin_addr, origin_count, origin_dt,
518+
target_disp, target_count, target_dt,
519+
result_addr, ucx_req);
520+
}
466521
return ret;
467522
}
468523

@@ -576,11 +631,14 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
576631
}
577632
}
578633

579-
int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
580-
struct ompi_datatype_t *origin_dt,
581-
int target, ptrdiff_t target_disp, int target_count,
582-
struct ompi_datatype_t *target_dt,
583-
struct ompi_op_t *op, struct ompi_win_t *win) {
634+
static
635+
int accumulate_req(const void *origin_addr, int origin_count,
636+
struct ompi_datatype_t *origin_dt,
637+
int target, ptrdiff_t target_disp, int target_count,
638+
struct ompi_datatype_t *target_dt,
639+
struct ompi_op_t *op, struct ompi_win_t *win,
640+
ompi_osc_ucx_request_t *ucx_req) {
641+
584642
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
585643
int ret = OMPI_SUCCESS;
586644

@@ -594,18 +652,10 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
594652
}
595653

596654
if (module->acc_single_intrinsic) {
597-
if (op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) {
598-
ret = atomic_op_replace_sum(module, op, target,
599-
origin_addr, origin_count, origin_dt,
600-
target_disp, target_count, target_dt,
601-
&(module->req_result));
602-
} else {
603-
ret = atomic_op_cswap(module, op, target,
604-
origin_addr, origin_count, origin_dt,
605-
target_disp, target_count, target_dt,
606-
&(module->req_result));
607-
}
608-
return ret;
655+
return do_atomic_op(module, op, target,
656+
origin_addr, origin_count, origin_dt,
657+
target_disp, target_count, target_dt,
658+
NULL, ucx_req);
609659
}
610660

611661

@@ -712,9 +762,23 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
712762
free(temp_addr_holder);
713763
}
714764

765+
if (NULL != ucx_req) {
766+
// nothing to wait for, mark request as completed
767+
ompi_request_complete(&ucx_req->super, true);
768+
}
769+
715770
return end_atomicity(module, target);
716771
}
717772

773+
int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
774+
struct ompi_datatype_t *origin_dt,
775+
int target, ptrdiff_t target_disp, int target_count,
776+
struct ompi_datatype_t *target_dt,
777+
struct ompi_op_t *op, struct ompi_win_t *win) {
778+
return accumulate_req(origin_addr, origin_count, origin_dt, target,
779+
target_disp, target_count, target_dt, op, win, NULL);
780+
}
781+
718782
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
719783
void *result_addr, struct ompi_datatype_t *dt,
720784
int target, ptrdiff_t target_disp,
@@ -813,13 +877,15 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
813877
}
814878
}
815879

816-
int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
817-
struct ompi_datatype_t *origin_dt,
818-
void *result_addr, int result_count,
819-
struct ompi_datatype_t *result_dt,
820-
int target, ptrdiff_t target_disp,
821-
int target_count, struct ompi_datatype_t *target_dt,
822-
struct ompi_op_t *op, struct ompi_win_t *win) {
880+
static
881+
int get_accumulate_req(const void *origin_addr, int origin_count,
882+
struct ompi_datatype_t *origin_dt,
883+
void *result_addr, int result_count,
884+
struct ompi_datatype_t *result_dt,
885+
int target, ptrdiff_t target_disp,
886+
int target_count, struct ompi_datatype_t *target_dt,
887+
struct ompi_op_t *op, struct ompi_win_t *win,
888+
ompi_osc_ucx_request_t *ucx_req) {
823889
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
824890
int ret = OMPI_SUCCESS;
825891

@@ -829,19 +895,12 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
829895
}
830896

831897
if (module->acc_single_intrinsic) {
832-
if (op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) {
833-
ret = atomic_op_replace_sum(module, op, target,
834-
origin_addr, origin_count, origin_dt,
835-
target_disp, target_count, target_dt, result_addr);
836-
} else {
837-
ret = atomic_op_cswap(module, op, target,
838-
origin_addr, origin_count, origin_dt,
839-
target_disp, target_count, target_dt, result_addr);
840-
}
841-
return ret;
898+
return do_atomic_op(module, op, target,
899+
origin_addr, origin_count, origin_dt,
900+
target_disp, target_count, target_dt,
901+
result_addr, ucx_req);
842902
}
843903

844-
845904
ret = start_atomicity(module, target);
846905
if (ret != OMPI_SUCCESS) {
847906
return ret;
@@ -953,9 +1012,28 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
9531012
}
9541013
}
9551014

1015+
if (NULL != ucx_req) {
1016+
// nothing to wait for, mark request as completed
1017+
ompi_request_complete(&ucx_req->super, true);
1018+
}
1019+
1020+
9561021
return end_atomicity(module, target);
9571022
}
9581023

1024+
int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
1025+
struct ompi_datatype_t *origin_dt,
1026+
void *result_addr, int result_count,
1027+
struct ompi_datatype_t *result_dt,
1028+
int target, ptrdiff_t target_disp,
1029+
int target_count, struct ompi_datatype_t *target_dt,
1030+
struct ompi_op_t *op, struct ompi_win_t *win) {
1031+
1032+
return get_accumulate_req(origin_addr, origin_count, origin_dt, result_addr,
1033+
result_count, result_dt, target, target_disp,
1034+
target_count, target_dt, op, win, NULL);
1035+
}
1036+
9591037
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
9601038
struct ompi_datatype_t *origin_dt,
9611039
int target, ptrdiff_t target_disp, int target_count,
@@ -1077,14 +1155,13 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
10771155
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
10781156
assert(NULL != ucx_req);
10791157

1080-
ret = ompi_osc_ucx_accumulate(origin_addr, origin_count, origin_dt, target, target_disp,
1081-
target_count, target_dt, op, win);
1158+
ret = accumulate_req(origin_addr, origin_count, origin_dt, target, target_disp,
1159+
target_count, target_dt, op, win, ucx_req);
10821160
if (ret != OMPI_SUCCESS) {
10831161
OMPI_OSC_UCX_REQUEST_RETURN(ucx_req);
10841162
return ret;
10851163
}
10861164

1087-
ompi_request_complete(&ucx_req->super, true);
10881165
*request = &ucx_req->super;
10891166

10901167
return ret;
@@ -1110,17 +1187,15 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
11101187
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
11111188
assert(NULL != ucx_req);
11121189

1113-
ret = ompi_osc_ucx_get_accumulate(origin_addr, origin_count, origin_datatype,
1114-
result_addr, result_count, result_datatype,
1115-
target, target_disp, target_count,
1116-
target_datatype, op, win);
1190+
ret = get_accumulate_req(origin_addr, origin_count, origin_datatype,
1191+
result_addr, result_count, result_datatype,
1192+
target, target_disp, target_count,
1193+
target_datatype, op, win, ucx_req);
11171194
if (ret != OMPI_SUCCESS) {
11181195
OMPI_OSC_UCX_REQUEST_RETURN(ucx_req);
11191196
return ret;
11201197
}
11211198

1122-
ompi_request_complete(&ucx_req->super, true);
1123-
11241199
*request = &ucx_req->super;
11251200

11261201
return ret;

0 commit comments

Comments
 (0)