Skip to content

Commit d8696aa

Browse files
committed
UCX osc: centralize decision on whether to use AMOs
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent 427d4bd commit d8696aa

File tree

1 file changed

+46
-38
lines changed

1 file changed

+46
-38
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ static inline int start_atomicity(
264264
return OMPI_ERROR;
265265
}
266266
if (result_value == TARGET_LOCK_UNLOCKED) {
267-
return OMPI_SUCCESS;
267+
break;
268268
}
269269

270270
ucp_worker_progress(mca_osc_ucx_component.wpool->dflt_worker);
@@ -274,6 +274,8 @@ static inline int start_atomicity(
274274
} else {
275275
*lock_acquired = false;
276276
}
277+
278+
return OMPI_SUCCESS;
277279
}
278280

279281
static inline int end_atomicity(
@@ -362,16 +364,30 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
362364
}
363365

364366
static inline
365-
bool use_ucx_op(struct ompi_op_t *op, struct ompi_datatype_t *origin_dt)
367+
bool use_atomic_op(
368+
ompi_osc_ucx_module_t *module,
369+
struct ompi_op_t *op,
370+
struct ompi_datatype_t *origin_dt,
371+
struct ompi_datatype_t *target_dt,
372+
int origin_count,
373+
int target_count)
366374
{
367375

368-
if (op == &ompi_mpi_op_replace.op ||
369-
op == &ompi_mpi_op_sum.op ||
370-
op == &ompi_mpi_op_no_op.op) {
371-
size_t dt_bytes;
372-
ompi_datatype_type_size(origin_dt, &dt_bytes);
373-
if (ompi_datatype_is_predefined(origin_dt) &&
374-
sizeof(uint64_t) >= dt_bytes) {
376+
if (module->acc_single_intrinsic &&
377+
ompi_datatype_is_predefined(origin_dt) &&
378+
origin_count == 1 &&
379+
(op == &ompi_mpi_op_replace.op ||
380+
op == &ompi_mpi_op_sum.op ||
381+
op == &ompi_mpi_op_no_op.op)) {
382+
size_t origin_dt_bytes;
383+
size_t target_dt_bytes;
384+
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
385+
ompi_datatype_type_size(target_dt, &target_dt_bytes);
386+
/* UCX currently supports only 32-bit and 64-bit operands */
387+
if (sizeof(uint64_t) >= origin_dt_bytes &&
388+
sizeof(uint32_t) <= origin_dt_bytes &&
389+
origin_dt_bytes == target_dt_bytes &&
390+
origin_count == target_count) {
375391
return true;
376392
}
377393
}
@@ -384,25 +400,15 @@ static int do_atomic_op_intrinsic(
384400
struct ompi_op_t *op,
385401
int target,
386402
const void *origin_addr,
387-
int origin_count,
388-
struct ompi_datatype_t *origin_dt,
403+
int count,
404+
struct ompi_datatype_t *dt,
389405
ptrdiff_t target_disp,
390-
int target_count,
391-
struct ompi_datatype_t *target_dt,
392406
void *result_addr,
393407
ompi_osc_ucx_request_t *ucx_req)
394408
{
395409
int ret = OMPI_SUCCESS;
396410
size_t origin_dt_bytes;
397-
size_t target_dt_bytes;
398-
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
399-
ompi_datatype_type_size(target_dt, &target_dt_bytes);
400-
401-
if (sizeof(uint64_t) > origin_dt_bytes ||
402-
origin_dt_bytes != target_dt_bytes ||
403-
target_count != origin_count) {
404-
return OMPI_ERR_NOT_SUPPORTED;
405-
}
411+
ompi_datatype_type_size(dt, &origin_dt_bytes);
406412

407413
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
408414

@@ -430,9 +436,9 @@ static int do_atomic_op_intrinsic(
430436
if( result_addr ) {
431437
output_addr = result_addr;
432438
}
433-
for (int i = 0; i < origin_count; ++i) {
439+
for (int i = 0; i < count; ++i) {
434440
uint64_t value = 0;
435-
if ((origin_count - 1) == i && NULL != ucx_req) {
441+
if ((count - 1) == i && NULL != ucx_req) {
436442
// the last item is used to feed the request, if needed
437443
user_req_cb = &req_completion;
438444
user_req_ptr = ucx_req;
@@ -596,11 +602,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
596602
return ret;
597603
}
598604

599-
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
605+
/* rely on UCX network atomics if the user told us that it is safe */
606+
if (use_atomic_op(module, op, origin_dt, target_dt, origin_count, target_count)) {
600607
return do_atomic_op_intrinsic(module, op, target,
601608
origin_addr, origin_count, origin_dt,
602-
target_disp, target_count, target_dt,
603-
NULL, ucx_req);
609+
target_disp, NULL, ucx_req);
604610
}
605611

606612
ret = start_atomicity(module, target, &lock_acquired);
@@ -726,14 +732,21 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
726732
int ret = OMPI_SUCCESS;
727733
bool lock_acquired = false;
728734

735+
ompi_datatype_type_size(dt, &dt_bytes);
736+
if (sizeof(uint64_t) < dt_bytes) {
737+
return OMPI_ERR_NOT_SUPPORTED;
738+
}
739+
729740
ret = check_sync_state(module, target, false);
730741
if (ret != OMPI_SUCCESS) {
731742
return ret;
732743
}
733744

734-
ret = start_atomicity(module, target, &lock_acquired);
735-
if (ret != OMPI_SUCCESS) {
736-
return ret;
745+
if (!module->acc_single_intrinsic) {
746+
ret = start_atomicity(module, target, &lock_acquired);
747+
if (ret != OMPI_SUCCESS) {
748+
return ret;
749+
}
737750
}
738751

739752
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
@@ -743,11 +756,6 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
743756
}
744757
}
745758

746-
ompi_datatype_type_size(dt, &dt_bytes);
747-
if (sizeof(uint64_t) < dt_bytes) {
748-
return OMPI_ERR_NOT_SUPPORTED;
749-
}
750-
751759
uint64_t compare_val;
752760
memcpy(&compare_val, compare_addr, dt_bytes);
753761
memcpy(result_addr, origin_addr, dt_bytes);
@@ -841,11 +849,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
841849
return ret;
842850
}
843851

844-
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
852+
/* rely on UCX network atomics if the user told us that it is safe */
853+
if (use_atomic_op(module, op, origin_dt, target_dt, origin_count, target_count)) {
845854
return do_atomic_op_intrinsic(module, op, target,
846855
origin_addr, origin_count, origin_dt,
847-
target_disp, target_count, target_dt,
848-
result_addr, ucx_req);
856+
target_disp, result_addr, ucx_req);
849857
}
850858

851859
ret = start_atomicity(module, target, &lock_acquired);

0 commit comments

Comments
 (0)