Skip to content

Commit 434c905

Browse files
committed
UCX osc: fall back to get-compare-put for unsupported datatypes
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent 7d5a6e3 commit 434c905

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,36 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
722722
target_disp, target_count, target_dt, op, win, NULL);
723723
}
724724

725+
static int
726+
do_atomic_compare_and_swap(const void *origin_addr, const void *compare_addr,
727+
void *result_addr, struct ompi_datatype_t *dt,
728+
int target, uint64_t remote_addr,
729+
ompi_osc_ucx_module_t *module)
730+
{
731+
int ret;
732+
bool lock_acquired = false;
733+
size_t dt_bytes;
734+
if (!module->acc_single_intrinsic) {
735+
ret = start_atomicity(module, target, &lock_acquired);
736+
if (ret != OMPI_SUCCESS) {
737+
return ret;
738+
}
739+
}
740+
741+
ompi_datatype_type_size(dt, &dt_bytes);
742+
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
743+
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
744+
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
745+
result_addr, dt_bytes, remote_addr,
746+
NULL, NULL);
747+
748+
if (module->acc_single_intrinsic) {
749+
return ret;
750+
}
751+
752+
return end_atomicity(module, target, lock_acquired, NULL);
753+
}
754+
725755
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
726756
void *result_addr, struct ompi_datatype_t *dt,
727757
int target, ptrdiff_t target_disp,
@@ -732,40 +762,55 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
732762
int ret = OMPI_SUCCESS;
733763
bool lock_acquired = false;
734764

735-
ompi_datatype_type_size(dt, &dt_bytes);
736-
if (sizeof(uint64_t) < dt_bytes) {
737-
return OMPI_ERR_NOT_SUPPORTED;
738-
}
739-
740765
ret = check_sync_state(module, target, false);
741766
if (ret != OMPI_SUCCESS) {
742767
return ret;
743768
}
744769

745-
if (!module->acc_single_intrinsic) {
746-
ret = start_atomicity(module, target, &lock_acquired);
747-
if (ret != OMPI_SUCCESS) {
748-
return ret;
749-
}
750-
}
751-
752770
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
753771
ret = get_dynamic_win_info(remote_addr, module, target);
754772
if (ret != OMPI_SUCCESS) {
755773
return ret;
756774
}
757775
}
758776

759-
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
760-
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
761-
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
762-
result_addr, dt_bytes, remote_addr,
763-
NULL, NULL);
777+
ompi_datatype_type_size(dt, &dt_bytes);
778+
if (4 == dt_bytes || 8 == dt_bytes) {
779+
// fast path using UCX atomic operations
780+
return do_atomic_compare_and_swap(origin_addr, compare_addr,
781+
result_addr, dt, target,
782+
remote_addr, module);
783+
}
764784

765-
if (module->acc_single_intrinsic) {
785+
/* fall back to get-compare-put */
786+
787+
ret = start_atomicity(module, target, &lock_acquired);
788+
if (ret != OMPI_SUCCESS) {
766789
return ret;
767790
}
768791

792+
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
793+
&result_addr, dt_bytes, remote_addr);
794+
if (OPAL_SUCCESS != ret) {
795+
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
796+
return OMPI_ERROR;
797+
}
798+
799+
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
800+
if (ret != OPAL_SUCCESS) {
801+
return ret;
802+
}
803+
804+
if (0 == memcmp(result_addr, compare_addr, dt_bytes)) {
805+
// write the new value
806+
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
807+
(void*)origin_addr, dt_bytes, remote_addr);
808+
if (OPAL_SUCCESS != ret) {
809+
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
810+
return OMPI_ERROR;
811+
}
812+
}
813+
769814
return end_atomicity(module, target, lock_acquired, NULL);
770815
}
771816

opal/mca/common/ucx/common_ucx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *
120120
* Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.
121121
*/
122122
static inline
123-
uint64_t opal_common_ucx_load_uint64(void *ptr, size_t size)
123+
uint64_t opal_common_ucx_load_uint64(const void *ptr, size_t size)
124124
{
125125
if (sizeof(uint8_t) == size) {
126126
return *(uint8_t*)ptr;

0 commit comments

Comments
 (0)