Skip to content

Commit 73a1834

Browse files
committed
UCX osc: add support for acc_single_intrinsic info key / mca param
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent e1e8b2a commit 73a1834

File tree

3 files changed

+201
-9
lines changed

3 files changed

+201
-9
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ typedef struct ompi_osc_ucx_component {
3434
int num_incomplete_req_ops;
3535
int num_modules;
3636
bool no_locks; /* Default value of the no_locks info key for new windows */
37+
bool acc_single_intrinsic;
3738
unsigned int priority;
3839
} ompi_osc_ucx_component_t;
3940

@@ -115,6 +116,7 @@ typedef struct ompi_osc_ucx_module {
115116
int *start_grp_ranks;
116117
bool lock_all_is_nocheck;
117118
bool no_locks;
119+
bool acc_single_intrinsic;
118120
opal_common_ucx_ctx_t *ctx;
119121
opal_common_ucx_wpmem_t *mem;
120122
opal_common_ucx_wpmem_t *state_mem;

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 187 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,149 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323323
return ret;
324324
}
325325

326+
static int atomic_op_replace_sum(
327+
ompi_osc_ucx_module_t *module,
328+
struct ompi_op_t *op,
329+
int target,
330+
const void *origin_addr,
331+
int origin_count,
332+
struct ompi_datatype_t *origin_dt,
333+
ptrdiff_t target_disp,
334+
int target_count,
335+
struct ompi_datatype_t *target_dt,
336+
void *result_addr)
337+
{
338+
int ret = OMPI_SUCCESS;
339+
size_t origin_dt_bytes;
340+
size_t target_dt_bytes;
341+
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
342+
ompi_datatype_type_size(target_dt, &target_dt_bytes);
343+
344+
if (origin_dt_bytes > sizeof(uint64_t) ||
345+
origin_dt_bytes != target_dt_bytes ||
346+
target_count != origin_count) {
347+
return OMPI_ERR_NOT_SUPPORTED;
348+
}
349+
350+
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
351+
352+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
353+
ret = get_dynamic_win_info(remote_addr, module, target);
354+
if (ret != OMPI_SUCCESS) {
355+
return ret;
356+
}
357+
}
358+
359+
ucp_atomic_fetch_op_t opcode;
360+
if (op == &ompi_mpi_op_replace.op) {
361+
opcode = UCP_ATOMIC_FETCH_OP_SWAP;
362+
} else {
363+
opcode = UCP_ATOMIC_FETCH_OP_FADD;
364+
}
365+
366+
for (int i = 0; i < origin_count; ++i) {
367+
uint64_t value = 0;
368+
memcpy(&value, origin_addr, origin_dt_bytes);
369+
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, opcode, value, target,
370+
result_addr ? result_addr : &(module->req_result),
371+
origin_dt_bytes, remote_addr, NULL, NULL);
372+
373+
// advance origin and remote address
374+
origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
375+
remote_addr += origin_dt_bytes;
376+
if (result_addr) {
377+
result_addr = (void*)((intptr_t)result_addr + origin_dt_bytes);
378+
}
379+
}
380+
381+
return ret;
382+
}
383+
384+
static int atomic_op_cswap(
385+
ompi_osc_ucx_module_t *module,
386+
struct ompi_op_t *op,
387+
int target,
388+
const void *origin_addr,
389+
int origin_count,
390+
struct ompi_datatype_t *origin_dt,
391+
ptrdiff_t target_disp,
392+
int target_count,
393+
struct ompi_datatype_t *target_dt,
394+
void *result_addr)
395+
{
396+
int ret = OMPI_SUCCESS;
397+
size_t origin_dt_bytes;
398+
size_t target_dt_bytes;
399+
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
400+
ompi_datatype_type_size(target_dt, &target_dt_bytes);
401+
402+
if (origin_dt_bytes > sizeof(uint64_t) ||
403+
origin_dt_bytes != target_dt_bytes ||
404+
target_count != origin_count) {
405+
return OMPI_ERR_NOT_SUPPORTED;
406+
}
407+
408+
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
409+
410+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
411+
ret = get_dynamic_win_info(remote_addr, module, target);
412+
if (ret != OMPI_SUCCESS) {
413+
return ret;
414+
}
415+
}
416+
417+
for (int i = 0; i < origin_count; ++i) {
418+
419+
uint64_t tmp_val;
420+
do {
421+
uint64_t target_val = 0;
422+
423+
// get the value from the origin
424+
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET,
425+
target, &target_val, origin_dt_bytes,
426+
remote_addr);
427+
if (ret != OMPI_SUCCESS) {
428+
return ret;
429+
}
430+
431+
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
432+
if (ret != OMPI_SUCCESS) {
433+
return ret;
434+
}
435+
436+
tmp_val = target_val;
437+
// compute the result value
438+
ompi_op_reduce(op, (void *)origin_addr, &tmp_val, 1, origin_dt);
439+
440+
// compare-and-swap the resulting value
441+
ret = opal_common_ucx_wpmem_cmpswp(module->mem, target_val, tmp_val,
442+
target, &tmp_val, origin_dt_bytes,
443+
remote_addr);
444+
if (ret != OMPI_SUCCESS) {
445+
return ret;
446+
}
447+
448+
// check whether the conditional swap was successful
449+
if (tmp_val == target_val) {
450+
break;
451+
}
452+
453+
} while (1);
454+
455+
// store the result if necessary
456+
if (NULL != result_addr) {
457+
memcpy(result_addr, &tmp_val, origin_dt_bytes);
458+
result_addr = (void*)((intptr_t)result_addr + origin_dt_bytes);
459+
}
460+
// advance origin and remote address
461+
origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
462+
remote_addr += origin_dt_bytes;
463+
}
464+
465+
return ret;
466+
}
467+
468+
326469
int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
327470
int target, ptrdiff_t target_disp, int target_count,
328471
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
@@ -449,6 +592,22 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
449592
return ret;
450593
}
451594

595+
if (module->acc_single_intrinsic) {
596+
if (op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) {
597+
ret = atomic_op_replace_sum(module, op, target,
598+
origin_addr, origin_count, origin_dt,
599+
target_disp, target_count, target_dt,
600+
&(module->req_result));
601+
} else {
602+
ret = atomic_op_cswap(module, op, target,
603+
origin_addr, origin_count, origin_dt,
604+
target_disp, target_count, target_dt,
605+
&(module->req_result));
606+
}
607+
return ret;
608+
}
609+
610+
452611
ret = start_atomicity(module, target);
453612
if (ret != OMPI_SUCCESS) {
454613
return ret;
@@ -569,9 +728,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
569728
return ret;
570729
}
571730

572-
ret = start_atomicity(module, target);
573-
if (ret != OMPI_SUCCESS) {
574-
return ret;
731+
if (!module->acc_single_intrinsic) {
732+
ret = start_atomicity(module, target);
733+
if (ret != OMPI_SUCCESS) {
734+
return ret;
735+
}
575736
}
576737

577738
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
@@ -585,7 +746,8 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
585746
ret = opal_common_ucx_wpmem_cmpswp(module->mem,*(uint64_t *)compare_addr,
586747
*(uint64_t *)origin_addr, target,
587748
result_addr, dt_bytes, remote_addr);
588-
if (ret != OMPI_SUCCESS) {
749+
750+
if (module->acc_single_intrinsic) {
589751
return ret;
590752
}
591753

@@ -611,9 +773,11 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
611773
ucp_atomic_fetch_op_t opcode;
612774
size_t dt_bytes;
613775

614-
ret = start_atomicity(module, target);
615-
if (ret != OMPI_SUCCESS) {
616-
return ret;
776+
if (!module->acc_single_intrinsic) {
777+
ret = start_atomicity(module, target);
778+
if (ret != OMPI_SUCCESS) {
779+
return ret;
780+
}
617781
}
618782

619783
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
@@ -636,7 +800,8 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
636800

637801
ret = opal_common_ucx_wpmem_fetch(module->mem, opcode, value, target,
638802
(void *)result_addr, dt_bytes, remote_addr);
639-
if (ret != OMPI_SUCCESS) {
803+
804+
if (module->acc_single_intrinsic) {
640805
return ret;
641806
}
642807

@@ -662,6 +827,20 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
662827
return ret;
663828
}
664829

830+
if (module->acc_single_intrinsic) {
831+
if (op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) {
832+
ret = atomic_op_replace_sum(module, op, target,
833+
origin_addr, origin_count, origin_dt,
834+
target_disp, target_count, target_dt, result_addr);
835+
} else {
836+
ret = atomic_op_cswap(module, op, target,
837+
origin_addr, origin_count, origin_dt,
838+
target_disp, target_count, target_dt, result_addr);
839+
}
840+
return ret;
841+
}
842+
843+
665844
ret = start_atomicity(module, target);
666845
if (ret != OMPI_SUCCESS) {
667846
return ret;

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = {
7272
.wpool = NULL,
7373
.env_initialized = false,
7474
.num_incomplete_req_ops = 0,
75-
.num_modules = 0
75+
.num_modules = 0,
76+
.acc_single_intrinsic = false
7677
};
7778

7879
ompi_osc_ucx_module_t ompi_osc_ucx_module_template = {
@@ -167,6 +168,15 @@ static int component_register(void) {
167168
MCA_BASE_VAR_SCOPE_GROUP, &mca_osc_ucx_component.no_locks);
168169
free(description_str);
169170

171+
mca_osc_ucx_component.acc_single_intrinsic = false;
172+
opal_asprintf(&description_str, "Enable optimizations for MPI_Fetch_and_op, MPI_Accumulate, etc for codes "
173+
"that will not use anything more than a single predefined datatype (default: %s)",
174+
mca_osc_ucx_component.acc_single_intrinsic ? "true" : "false");
175+
(void) mca_base_component_var_register(&mca_osc_ucx_component.super.osc_version, "acc_single_intrinsic",
176+
description_str, MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5,
177+
MCA_BASE_VAR_SCOPE_GROUP, &mca_osc_ucx_component.acc_single_intrinsic);
178+
free(description_str);
179+
170180
opal_common_ucx_mca_var_register(&mca_osc_ucx_component.super.osc_version);
171181

172182
return OMPI_SUCCESS;
@@ -389,6 +399,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
389399
module->flavor = flavor;
390400
module->size = size;
391401
module->no_locks = check_config_value_bool ("no_locks", info);
402+
module->acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic", info);
392403

393404
/* share everyone's displacement units. Only do an allgather if
394405
strictly necessary, since it requires O(p) state. */

0 commit comments

Comments
 (0)