Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions opal/mca/common/ucx/common_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ void opal_common_ucx_mca_proc_added(void)
#endif
}

OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence_nb(int *fenced)
{
return opal_pmix.fence_nb(NULL, 0, opal_common_ucx_mca_fence_complete_cb, (void *)fenced);
}

OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence(ucp_worker_h worker)
{
volatile int fenced = 0;
Expand Down Expand Up @@ -181,9 +186,8 @@ static void opal_common_ucx_wait_all_requests(void **reqs, int count, ucp_worker
}
}

OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, size_t count,
size_t my_rank, size_t max_disconnect, ucp_worker_h worker)
{
OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *procs, size_t count,
size_t my_rank, size_t max_disconnect, ucp_worker_h worker) {
size_t num_reqs;
size_t max_reqs;
void *dreq, **dreqs;
Expand Down Expand Up @@ -230,7 +234,13 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, s
opal_common_ucx_wait_all_requests(dreqs, num_reqs, worker);
free(dreqs);

opal_common_ucx_mca_pmix_fence(worker);

return OPAL_SUCCESS;
}

OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, size_t count,
size_t my_rank, size_t max_disconnect, ucp_worker_h worker)
{
opal_common_ucx_del_procs_nofence(procs, count, my_rank, max_disconnect, worker);

return opal_common_ucx_mca_pmix_fence(worker);
}
5 changes: 4 additions & 1 deletion opal/mca/common/ucx/common_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ OPAL_DECLSPEC void opal_common_ucx_mca_deregister(void);
OPAL_DECLSPEC void opal_common_ucx_mca_proc_added(void);
OPAL_DECLSPEC void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status);
OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence(ucp_worker_h worker);
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);
OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence_nb(int *fenced);
OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, size_t count,
size_t my_rank, size_t max_disconnect, ucp_worker_h worker);
OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *procs, size_t count,
size_t my_rank, size_t max_disconnect, ucp_worker_h worker);
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);

static inline
ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request)
Expand Down
2 changes: 1 addition & 1 deletion oshmem/mca/atomic/ucx/atomic_ucx_cswap.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx,
assert(NULL != prev);

*prev = value;
ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
UCP_ATOMIC_FETCH_OP_CSWAP, cond, prev, size,
rva, ucx_mkey->rkey,
Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/atomic/ucx/atomic_ucx_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx,

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, size, rva,
ucx_mkey->rkey);
Expand All @@ -70,7 +70,7 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx,

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, prev, size,
rva, ucx_mkey->rkey,
Expand Down
8 changes: 5 additions & 3 deletions oshmem/mca/memheap/base/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ void memheap_oob_destruct(void);
OSHMEM_DECLSPEC int mca_memheap_base_is_symmetric_addr(const void* va);
OSHMEM_DECLSPEC sshmem_mkey_t *mca_memheap_base_get_mkey(void* va,
int tr_id);
OSHMEM_DECLSPEC sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
OSHMEM_DECLSPEC sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(shmem_ctx_t ctx,
map_segment_t *s,
int pe,
void* va,
int btl_id,
Expand Down Expand Up @@ -243,7 +244,8 @@ static inline map_segment_t *memheap_find_va(void* va)
return s;
}

static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(int pe,
static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(shmem_ctx_t ctx,
int pe,
void* va,
int btl_id,
void** rva)
Expand Down Expand Up @@ -273,7 +275,7 @@ static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(int pe,
return mkey;
}

return mca_memheap_base_get_cached_mkey_slow(s, pe, va, btl_id, rva);
return mca_memheap_base_get_cached_mkey_slow(ctx, s, pe, va, btl_id, rva);
}

static inline int mca_memheap_base_num_transports(void)
Expand Down
21 changes: 12 additions & 9 deletions oshmem/mca/memheap/base/memheap_base_mkey.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct oob_comm {
oob_comm_request_t req_pool[MEMHEAP_RECV_REQS_MAX];
opal_list_t req_list;
int is_inited;
shmem_ctx_t ctx;
};

mca_memheap_map_t* memheap_map = NULL;
Expand All @@ -66,7 +67,7 @@ static int send_buffer(int pe, opal_buffer_t *msg);
static int oshmem_mkey_recv_cb(void);

/* pickup list of rkeys and remote va */
static int memheap_oob_get_mkeys(int pe,
static int memheap_oob_get_mkeys(shmem_ctx_t ctx, int pe,
uint32_t va_seg_num,
sshmem_mkey_t *mkey);

Expand Down Expand Up @@ -142,7 +143,7 @@ static void memheap_attach_segment(sshmem_mkey_t *mkey, int tr_id)
}


static void unpack_remote_mkeys(opal_buffer_t *msg, int remote_pe)
static void unpack_remote_mkeys(shmem_ctx_t ctx, opal_buffer_t *msg, int remote_pe)
{
int32_t cnt;
int32_t n;
Expand Down Expand Up @@ -182,7 +183,7 @@ static void unpack_remote_mkeys(opal_buffer_t *msg, int remote_pe)
} else {
memheap_oob.mkeys[tr_id].u.key = MAP_SEGMENT_SHM_INVALID;
}
MCA_SPML_CALL(rmkey_unpack(&memheap_oob.mkeys[tr_id], memheap_oob.segno, remote_pe, tr_id));
MCA_SPML_CALL(rmkey_unpack(ctx, &memheap_oob.mkeys[tr_id], memheap_oob.segno, remote_pe, tr_id));
}

MEMHEAP_VERBOSE(5,
Expand Down Expand Up @@ -242,7 +243,7 @@ static void do_recv(int source_pe, opal_buffer_t* buffer)
case MEMHEAP_RKEY_RESP:
MEMHEAP_VERBOSE(5, "*** RKEY RESP");
OPAL_THREAD_LOCK(&memheap_oob.lck);
unpack_remote_mkeys(buffer, source_pe);
unpack_remote_mkeys(memheap_oob.ctx, buffer, source_pe);
memheap_oob.mkeys_rcvd = MEMHEAP_RKEY_RESP;
opal_condition_broadcast(&memheap_oob.cond);
OPAL_THREAD_UNLOCK(&memheap_oob.lck);
Expand Down Expand Up @@ -455,14 +456,14 @@ static int send_buffer(int pe, opal_buffer_t *msg)
return rc;
}

static int memheap_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
static int memheap_oob_get_mkeys(shmem_ctx_t ctx, int pe, uint32_t seg, sshmem_mkey_t *mkeys)
{
opal_buffer_t *msg;
uint8_t cmd;
int i;
int rc;

if (OSHMEM_SUCCESS == MCA_SPML_CALL(oob_get_mkeys(pe, seg, mkeys))) {
if (OSHMEM_SUCCESS == MCA_SPML_CALL(oob_get_mkeys(ctx, pe, seg, mkeys))) {
for (i = 0; i < memheap_map->num_transports; i++) {
MEMHEAP_VERBOSE(5,
"MKEY CALCULATED BY LOCAL SPML: pe: %d tr_id: %d %s",
Expand All @@ -478,6 +479,7 @@ static int memheap_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
memheap_oob.mkeys = mkeys;
memheap_oob.segno = seg;
memheap_oob.mkeys_rcvd = 0;
memheap_oob.ctx = ctx;

msg = OBJ_NEW(opal_buffer_t);
if (!msg) {
Expand Down Expand Up @@ -645,7 +647,7 @@ void mca_memheap_modex_recv_all(void)
}
memheap_oob.mkeys = s->mkeys_cache[i];
memheap_oob.segno = j;
unpack_remote_mkeys(msg, i);
unpack_remote_mkeys(oshmem_ctx_default, msg, i);
}
}

Expand Down Expand Up @@ -674,7 +676,8 @@ void mca_memheap_modex_recv_all(void)
}
}

sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(shmem_ctx_t ctx,
map_segment_t *s,
int pe,
void* va,
int btl_id,
Expand All @@ -692,7 +695,7 @@ sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
if (!s->mkeys_cache[pe])
return NULL ;

rc = memheap_oob_get_mkeys(pe,
rc = memheap_oob_get_mkeys(ctx, pe,
s - memheap_map->mem_segs,
s->mkeys_cache[pe]);
if (OSHMEM_SUCCESS != rc)
Expand Down
5 changes: 3 additions & 2 deletions oshmem/mca/spml/base/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ OSHMEM_DECLSPEC int mca_spml_base_test(void* addr,
void* value,
int datatype,
int *out_value);
OSHMEM_DECLSPEC int mca_spml_base_oob_get_mkeys(int pe,
OSHMEM_DECLSPEC int mca_spml_base_oob_get_mkeys(shmem_ctx_t ctx,
int pe,
uint32_t seg,
sshmem_mkey_t *mkeys);

OSHMEM_DECLSPEC void mca_spml_base_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t seg, int pe, int tr_id);
OSHMEM_DECLSPEC void mca_spml_base_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t seg, int pe, int tr_id);
OSHMEM_DECLSPEC void mca_spml_base_rmkey_free(sshmem_mkey_t *mkey);
OSHMEM_DECLSPEC void *mca_spml_base_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe);

Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/spml/base/spml_base.c
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ int mca_spml_base_wait_nb(void* handle)
return OSHMEM_SUCCESS;
}

int mca_spml_base_oob_get_mkeys(int pe, uint32_t segno, sshmem_mkey_t *mkeys)
int mca_spml_base_oob_get_mkeys(shmem_ctx_t ctx, int pe, uint32_t segno, sshmem_mkey_t *mkeys)
{
return OSHMEM_ERROR;
}

void mca_spml_base_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
void mca_spml_base_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
{
}

Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/spml/spml.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ typedef int (*mca_spml_base_module_test_fn_t)(void* addr,
*
* @param mkey remote mkey
*/
typedef void (*mca_spml_base_module_mkey_unpack_fn_t)(sshmem_mkey_t *, uint32_t segno, int remote_pe, int tr_id);
typedef void (*mca_spml_base_module_mkey_unpack_fn_t)(shmem_ctx_t ctx, sshmem_mkey_t *, uint32_t segno, int remote_pe, int tr_id);

/**
* If possible, get a pointer to the remote memory described by the mkey
Expand Down Expand Up @@ -180,7 +180,7 @@ typedef int (*mca_spml_base_module_deregister_fn_t)(sshmem_mkey_t *mkeys);
*
* @return OSHMEM_SUCCSESS if keys are found
*/
typedef int (*mca_spml_base_module_oob_get_mkeys_fn_t)(int pe,
typedef int (*mca_spml_base_module_oob_get_mkeys_fn_t)(shmem_ctx_t ctx, int pe,
uint32_t seg,
sshmem_mkey_t *mkeys);

Expand Down
Loading