Skip to content

Commit 7d5a6e3

Browse files
committed
UCX osc: safely load/store 64bit integer from variable size pointer
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent 824afac commit 7d5a6e3

File tree

2 files changed

+46
-24
lines changed

2 files changed

+46
-24
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ static int do_atomic_op_intrinsic(
454454
if (is_no_op) {
455455
value = 0;
456456
} else {
457-
memcpy(&value, origin_addr, origin_dt_bytes);
457+
value = opal_common_ucx_load_uint64(origin_addr, origin_dt_bytes);
458458
}
459459
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, opcode, value, target,
460460
output_addr, origin_dt_bytes, remote_addr,
@@ -756,13 +756,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
756756
}
757757
}
758758

759-
uint64_t compare_val;
760-
memcpy(&compare_val, compare_addr, dt_bytes);
761-
memcpy(result_addr, origin_addr, dt_bytes);
762-
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_CSWAP,
763-
compare_val, target,
764-
result_addr, dt_bytes, remote_addr,
765-
NULL, NULL);
759+
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
760+
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
761+
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
762+
result_addr, dt_bytes, remote_addr,
763+
NULL, NULL);
766764

767765
if (module->acc_single_intrinsic) {
768766
return ret;
@@ -785,8 +783,8 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
785783

786784
if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
787785
op == &ompi_mpi_op_sum.op) {
786+
uint64_t value;
788787
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
789-
uint64_t value = origin_addr ? *(uint64_t *)origin_addr : 0;
790788
ucp_atomic_fetch_op_t opcode;
791789
size_t dt_bytes;
792790
bool lock_acquired = false;
@@ -805,7 +803,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
805803
}
806804
}
807805

808-
ompi_datatype_type_size(dt, &dt_bytes);
806+
value = origin_addr ? opal_common_ucx_load_uint64(origin_addr, dt_bytes) : 0;
809807

810808
if (op == &ompi_mpi_op_replace.op) {
811809
opcode = UCP_ATOMIC_FETCH_OP_SWAP;

opal/mca/common/ucx/common_ucx.h

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,42 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *
115115
size_t my_rank, size_t max_disconnect, ucp_worker_h worker);
116116
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);
117117

118+
119+
/**
120+
* Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.
121+
*/
122+
static inline
123+
uint64_t opal_common_ucx_load_uint64(void *ptr, size_t size)
124+
{
125+
if (sizeof(uint8_t) == size) {
126+
return *(uint8_t*)ptr;
127+
} else if (sizeof(uint16_t) == size) {
128+
return *(uint16_t*)ptr;
129+
} else if (sizeof(uint32_t) == size) {
130+
return *(uint32_t*)ptr;
131+
} else {
132+
return *(uint64_t*)ptr;
133+
}
134+
}
135+
136+
/**
137+
* Cast and store a uint64_t value to a value of \c size bytes pointed to by \c ptr.
138+
*/
139+
static inline
140+
void opal_common_ucx_store_uint64(uint64_t value, void *ptr, size_t size)
141+
{
142+
if (sizeof(uint8_t) == size) {
143+
*(uint8_t*)ptr = value;
144+
} else if (sizeof(uint16_t) == size) {
145+
*(uint16_t*)ptr = value;
146+
} else if (sizeof(uint32_t) == size) {
147+
*(uint32_t*)ptr = value;
148+
} else {
149+
*(uint64_t*)ptr = value;
150+
}
151+
}
152+
153+
118154
static inline
119155
ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request)
120156
{
@@ -206,13 +242,7 @@ int opal_common_ucx_atomic_cswap(ucp_ep_h ep, uint64_t compare,
206242
uint64_t remote_addr, ucp_rkey_h rkey,
207243
ucp_worker_h worker)
208244
{
209-
if (op_size == sizeof(uint64_t)) {
210-
*(uint64_t*)result = value;
211-
} else {
212-
assert(op_size == sizeof(uint32_t));
213-
*(uint32_t*)result = value;
214-
}
215-
245+
opal_common_ucx_store_uint64(value, result, op_size);
216246
return opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_CSWAP, compare, result,
217247
op_size, remote_addr, rkey, worker);
218248
}
@@ -224,13 +254,7 @@ ucs_status_ptr_t opal_common_ucx_atomic_cswap_nb(ucp_ep_h ep, uint64_t compare,
224254
ucp_send_callback_t req_handler,
225255
ucp_worker_h worker)
226256
{
227-
if (op_size == sizeof(uint64_t)) {
228-
*(uint64_t*)result = value;
229-
} else {
230-
assert(op_size == sizeof(uint32_t));
231-
*(uint32_t*)result = value;
232-
}
233-
257+
opal_common_ucx_store_uint64(value, result, op_size);
234258
return opal_common_ucx_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_CSWAP, compare, result,
235259
op_size, remote_addr, rkey, req_handler, worker);
236260
}

0 commit comments

Comments
 (0)