Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
#include "osc_ucx.h"
#include "osc_ucx_request.h"


#define CHECK_VALID_RKEY(_module, _target, _count) \
if (!((_module)->win_info_array[_target]).rkey_init && ((_count) > 0)) { \
opal_output_verbose(1, ompi_osc_base_framework.framework_output, \
"%s:%d: window with non-zero length does not have an rkey\n", \
__FILE__, __LINE__); \
return OMPI_ERROR; \
}

typedef struct ucx_iovec {
void *addr;
size_t len;
Expand Down Expand Up @@ -337,7 +346,7 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module

if ((module->win_info_array[target]).rkey_init == true) {
ucp_rkey_destroy((module->win_info_array[target]).rkey);
(module->win_info_array[target]).rkey_init == false;
(module->win_info_array[target]).rkey_init = false;
}

status = ucp_get_nbi(ep, (void *)temp_buf, len, remote_state_addr, state_rkey);
Expand Down Expand Up @@ -404,6 +413,12 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
}
}

CHECK_VALID_RKEY(module, target, target_count);

if (!target_count) {
return OMPI_SUCCESS;
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
Expand Down Expand Up @@ -460,6 +475,12 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
}
}

CHECK_VALID_RKEY(module, target, target_count);

if (!target_count) {
return OMPI_SUCCESS;
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
Expand Down Expand Up @@ -900,6 +921,8 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
}
}

CHECK_VALID_RKEY(module, target, target_count);

rkey = (module->win_info_array[target]).rkey;

OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
Expand Down Expand Up @@ -963,6 +986,8 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
}
}

CHECK_VALID_RKEY(module, target, target_count);

rkey = (module->win_info_array[target]).rkey;

OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
Expand Down
44 changes: 31 additions & 13 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
#include "osc_ucx.h"
#include "osc_ucx_request.h"

#define memcpy_off(_dst, _src, _len, _off) \
memcpy(((char*)(_dst)) + (_off), _src, _len); \
(_off) += (_len);

static int component_open(void);
static int component_register(void);
static int component_init(bool enable_progress_threads, bool enable_mpi_threads);
Expand Down Expand Up @@ -325,6 +329,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
int disps[comm_size];
int rkey_sizes[comm_size];
uint64_t zero = 0;
size_t info_offset;
uint64_t size_u64;

/* the osc/sm component is the exclusive provider for support for
* shared memory windows */
Expand Down Expand Up @@ -538,22 +544,27 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error;
}

my_info_len = 2 * sizeof(uint64_t) + rkey_buffer_size + state_rkey_buffer_size;
size_u64 = (uint64_t)size;
my_info_len = 3 * sizeof(uint64_t) + rkey_buffer_size + state_rkey_buffer_size;
my_info = malloc(my_info_len);
if (my_info == NULL) {
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
goto error;
}

info_offset = 0;

if (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE) {
memcpy(my_info, base, sizeof(uint64_t));
memcpy_off(my_info, base, sizeof(uint64_t), info_offset);
} else {
memcpy(my_info, &zero, sizeof(uint64_t));
memcpy_off(my_info, &zero, sizeof(uint64_t), info_offset);
}
memcpy((void *)((char *)my_info + sizeof(uint64_t)), &state_base, sizeof(uint64_t));
memcpy((void *)((char *)my_info + 2 * sizeof(uint64_t)), rkey_buffer, rkey_buffer_size);
memcpy((void *)((char *)my_info + 2 * sizeof(uint64_t) + rkey_buffer_size),
state_rkey_buffer, state_rkey_buffer_size);
memcpy_off(my_info, &state_base, sizeof(uint64_t), info_offset);
memcpy_off(my_info, &size_u64, sizeof(uint64_t), info_offset);
memcpy_off(my_info, rkey_buffer, rkey_buffer_size, info_offset);
memcpy_off(my_info, state_rkey_buffer, state_rkey_buffer_size, info_offset);

assert(my_info_len == info_offset);

ret = allgather_len_and_info(my_info, (int)my_info_len, &recv_buf, disps, module->comm);
if (ret != OMPI_SUCCESS) {
Expand All @@ -569,15 +580,21 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in

for (i = 0; i < comm_size; i++) {
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, i);
uint64_t dest_size;
assert(ep != NULL);

memcpy(&(module->win_info_array[i]).addr, &recv_buf[disps[i]], sizeof(uint64_t));
memcpy(&(module->state_info_array[i]).addr, &recv_buf[disps[i] + sizeof(uint64_t)],
sizeof(uint64_t));
info_offset = disps[i];

memcpy(&(module->win_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t));
info_offset += sizeof(uint64_t);
memcpy(&(module->state_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t));
info_offset += sizeof(uint64_t);
memcpy(&dest_size, &recv_buf[info_offset], sizeof(uint64_t));
info_offset += sizeof(uint64_t);

(module->win_info_array[i]).rkey_init = false;
if (size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) {
status = ucp_ep_rkey_unpack(ep, &(recv_buf[disps[i] + 2 * sizeof(uint64_t)]),
if (dest_size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) {
status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset],
&((module->win_info_array[i]).rkey));
if (status != UCS_OK) {
opal_output_verbose(1, ompi_osc_base_framework.framework_output,
Expand All @@ -586,10 +603,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
ret = OMPI_ERROR;
goto error;
}
info_offset += rkey_sizes[i];
(module->win_info_array[i]).rkey_init = true;
}

status = ucp_ep_rkey_unpack(ep, &(recv_buf[disps[i] + 2 * sizeof(uint64_t) + rkey_sizes[i]]),
status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset],
&((module->state_info_array[i]).rkey));
if (status != UCS_OK) {
opal_output_verbose(1, ompi_osc_base_framework.framework_output,
Expand Down