Skip to content

Commit 4d7a385

Browse files
committed
UCX osc: Use accumulate for operations/datatypes that are not covered by UCX
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent 899f58c commit 4d7a385

File tree

1 file changed

+30
-139
lines changed

1 file changed

+30
-139
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 30 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,25 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323323
return ret;
324324
}
325325

326-
static int do_atomic_op_replace_sum(
326+
static inline
327+
bool use_ucx_op(struct ompi_op_t *op, struct ompi_datatype_t *origin_dt)
328+
{
329+
330+
if (op == &ompi_mpi_op_replace.op ||
331+
op == &ompi_mpi_op_sum.op ||
332+
op == &ompi_mpi_op_no_op.op) {
333+
size_t dt_bytes;
334+
ompi_datatype_type_size(origin_dt, &dt_bytes);
335+
if (ompi_datatype_is_predefined(origin_dt) &&
336+
sizeof(uint64_t) >= dt_bytes) {
337+
return true;
338+
}
339+
}
340+
341+
return false;
342+
}
343+
344+
static int do_atomic_op_intrinsic(
327345
ompi_osc_ucx_module_t *module,
328346
struct ompi_op_t *op,
329347
int target,
@@ -342,7 +360,7 @@ static int do_atomic_op_replace_sum(
342360
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
343361
ompi_datatype_type_size(target_dt, &target_dt_bytes);
344362

345-
if (origin_dt_bytes > sizeof(uint64_t) ||
363+
if (sizeof(uint64_t) > origin_dt_bytes ||
346364
origin_dt_bytes != target_dt_bytes ||
347365
target_count != origin_count) {
348366
return OMPI_ERR_NOT_SUPPORTED;
@@ -409,133 +427,6 @@ static int do_atomic_op_replace_sum(
409427
return ret;
410428
}
411429

412-
static int do_atomic_op_cswap(
413-
ompi_osc_ucx_module_t *module,
414-
struct ompi_op_t *op,
415-
int target,
416-
const void *origin_addr,
417-
int origin_count,
418-
struct ompi_datatype_t *origin_dt,
419-
ptrdiff_t target_disp,
420-
int target_count,
421-
struct ompi_datatype_t *target_dt,
422-
void *result_addr,
423-
ompi_osc_ucx_request_t *ucx_req)
424-
{
425-
int ret = OMPI_SUCCESS;
426-
size_t origin_dt_bytes;
427-
size_t target_dt_bytes;
428-
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
429-
ompi_datatype_type_size(target_dt, &target_dt_bytes);
430-
431-
if (origin_dt_bytes > sizeof(uint64_t) ||
432-
origin_dt_bytes != target_dt_bytes ||
433-
target_count != origin_count) {
434-
return OMPI_ERR_NOT_SUPPORTED;
435-
}
436-
437-
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
438-
439-
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
440-
ret = get_dynamic_win_info(remote_addr, module, target);
441-
if (ret != OMPI_SUCCESS) {
442-
return ret;
443-
}
444-
}
445-
446-
for (int i = 0; i < origin_count; ++i) {
447-
448-
uint64_t tmp_val;
449-
uint64_t target_val = 0;
450-
451-
// get the value from the origin
452-
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET,
453-
target, &target_val, origin_dt_bytes,
454-
remote_addr);
455-
if (ret != OMPI_SUCCESS) {
456-
return ret;
457-
}
458-
459-
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
460-
if (ret != OMPI_SUCCESS) {
461-
return ret;
462-
}
463-
464-
/* JS: move this loop into the request to overlap multiple cas operations? */
465-
do {
466-
467-
tmp_val = target_val;
468-
// compute the result value
469-
ompi_op_reduce(op, (void *)origin_addr, &tmp_val, 1, origin_dt);
470-
471-
// compare-and-swap the resulting value
472-
ret = opal_common_ucx_wpmem_cmpswp(module->mem, target_val, tmp_val,
473-
target, &tmp_val, origin_dt_bytes,
474-
remote_addr);
475-
if (ret != OMPI_SUCCESS) {
476-
return ret;
477-
}
478-
479-
// check whether the conditional swap was successful
480-
if (tmp_val == target_val) {
481-
break;
482-
}
483-
484-
target_val = tmp_val;
485-
486-
} while (1);
487-
488-
// store the result if necessary
489-
if (NULL != result_addr) {
490-
memcpy(result_addr, &tmp_val, origin_dt_bytes);
491-
result_addr = (void*)((intptr_t)result_addr + origin_dt_bytes);
492-
}
493-
// advance origin and remote address
494-
origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
495-
remote_addr += origin_dt_bytes;
496-
}
497-
498-
if (NULL != ucx_req) {
499-
// nothing to wait for so mark the request as completed
500-
ompi_request_complete(&ucx_req->super, true);
501-
}
502-
503-
return ret;
504-
}
505-
506-
static inline
507-
int do_atomic_op(
508-
ompi_osc_ucx_module_t *module,
509-
struct ompi_op_t *op,
510-
int target,
511-
const void *origin_addr,
512-
int origin_count,
513-
struct ompi_datatype_t *origin_dt,
514-
ptrdiff_t target_disp,
515-
int target_count,
516-
struct ompi_datatype_t *target_dt,
517-
void *result_addr,
518-
ompi_osc_ucx_request_t *ucx_req)
519-
{
520-
int ret;
521-
522-
if (op == &ompi_mpi_op_replace.op ||
523-
op == &ompi_mpi_op_sum.op ||
524-
op == &ompi_mpi_op_no_op.op) {
525-
ret = do_atomic_op_replace_sum(module, op, target,
526-
origin_addr, origin_count, origin_dt,
527-
target_disp, target_count, target_dt,
528-
result_addr, ucx_req);
529-
} else {
530-
ret = do_atomic_op_cswap(module, op, target,
531-
origin_addr, origin_count, origin_dt,
532-
target_disp, target_count, target_dt,
533-
result_addr, ucx_req);
534-
}
535-
return ret;
536-
}
537-
538-
539430
int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
540431
int target, ptrdiff_t target_disp, int target_count,
541432
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
@@ -665,11 +556,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
665556
return ret;
666557
}
667558

668-
if (module->acc_single_intrinsic) {
669-
return do_atomic_op(module, op, target,
670-
origin_addr, origin_count, origin_dt,
671-
target_disp, target_count, target_dt,
672-
NULL, ucx_req);
559+
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
560+
return do_atomic_op_intrinsic(module, op, target,
561+
origin_addr, origin_count, origin_dt,
562+
target_disp, target_count, target_dt,
563+
NULL, ucx_req);
673564
}
674565

675566

@@ -923,11 +814,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
923814
return ret;
924815
}
925816

926-
if (module->acc_single_intrinsic) {
927-
return do_atomic_op(module, op, target,
928-
origin_addr, origin_count, origin_dt,
929-
target_disp, target_count, target_dt,
930-
result_addr, ucx_req);
817+
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
818+
return do_atomic_op_intrinsic(module, op, target,
819+
origin_addr, origin_count, origin_dt,
820+
target_disp, target_count, target_dt,
821+
result_addr, ucx_req);
931822
}
932823

933824
ret = start_atomicity(module, target);

0 commit comments

Comments
 (0)