Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion oshmem/mca/atomic/ucx/atomic_ucx_cswap.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx,
assert(NULL != prev);

*prev = value;
ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
UCP_ATOMIC_FETCH_OP_CSWAP, cond, prev, size,
rva, ucx_mkey->rkey,
Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/atomic/ucx/atomic_ucx_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx,

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, size, rva,
ucx_mkey->rkey);
Expand All @@ -70,7 +70,7 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx,

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, pe, target, (void *)&rva, mca_spml_self);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, prev, size,
rva, ucx_mkey->rkey,
Expand Down
8 changes: 5 additions & 3 deletions oshmem/mca/memheap/base/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ void memheap_oob_destruct(void);
OSHMEM_DECLSPEC int mca_memheap_base_is_symmetric_addr(const void* va);
OSHMEM_DECLSPEC sshmem_mkey_t *mca_memheap_base_get_mkey(void* va,
int tr_id);
OSHMEM_DECLSPEC sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
OSHMEM_DECLSPEC sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(shmem_ctx_t ctx,
map_segment_t *s,
int pe,
void* va,
int btl_id,
Expand Down Expand Up @@ -243,7 +244,8 @@ static inline map_segment_t *memheap_find_va(void* va)
return s;
}

static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(int pe,
static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(shmem_ctx_t ctx,
int pe,
void* va,
int btl_id,
void** rva)
Expand Down Expand Up @@ -273,7 +275,7 @@ static inline sshmem_mkey_t *mca_memheap_base_get_cached_mkey(int pe,
return mkey;
}

return mca_memheap_base_get_cached_mkey_slow(s, pe, va, btl_id, rva);
return mca_memheap_base_get_cached_mkey_slow(ctx, s, pe, va, btl_id, rva);
}

static inline int mca_memheap_base_num_transports(void)
Expand Down
21 changes: 12 additions & 9 deletions oshmem/mca/memheap/base/memheap_base_mkey.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct oob_comm {
oob_comm_request_t req_pool[MEMHEAP_RECV_REQS_MAX];
opal_list_t req_list;
int is_inited;
shmem_ctx_t ctx;
};

mca_memheap_map_t* memheap_map = NULL;
Expand All @@ -66,7 +67,7 @@ static int send_buffer(int pe, opal_buffer_t *msg);
static int oshmem_mkey_recv_cb(void);

/* pickup list of rkeys and remote va */
static int memheap_oob_get_mkeys(int pe,
static int memheap_oob_get_mkeys(shmem_ctx_t ctx, int pe,
uint32_t va_seg_num,
sshmem_mkey_t *mkey);

Expand Down Expand Up @@ -142,7 +143,7 @@ static void memheap_attach_segment(sshmem_mkey_t *mkey, int tr_id)
}


static void unpack_remote_mkeys(opal_buffer_t *msg, int remote_pe)
static void unpack_remote_mkeys(shmem_ctx_t ctx, opal_buffer_t *msg, int remote_pe)
{
int32_t cnt;
int32_t n;
Expand Down Expand Up @@ -182,7 +183,7 @@ static void unpack_remote_mkeys(opal_buffer_t *msg, int remote_pe)
} else {
memheap_oob.mkeys[tr_id].u.key = MAP_SEGMENT_SHM_INVALID;
}
MCA_SPML_CALL(rmkey_unpack(&memheap_oob.mkeys[tr_id], memheap_oob.segno, remote_pe, tr_id));
MCA_SPML_CALL(rmkey_unpack(ctx, &memheap_oob.mkeys[tr_id], memheap_oob.segno, remote_pe, tr_id));
}

MEMHEAP_VERBOSE(5,
Expand Down Expand Up @@ -242,7 +243,7 @@ static void do_recv(int source_pe, opal_buffer_t* buffer)
case MEMHEAP_RKEY_RESP:
MEMHEAP_VERBOSE(5, "*** RKEY RESP");
OPAL_THREAD_LOCK(&memheap_oob.lck);
unpack_remote_mkeys(buffer, source_pe);
unpack_remote_mkeys(memheap_oob.ctx, buffer, source_pe);
memheap_oob.mkeys_rcvd = MEMHEAP_RKEY_RESP;
opal_condition_broadcast(&memheap_oob.cond);
OPAL_THREAD_UNLOCK(&memheap_oob.lck);
Expand Down Expand Up @@ -455,14 +456,14 @@ static int send_buffer(int pe, opal_buffer_t *msg)
return rc;
}

static int memheap_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
static int memheap_oob_get_mkeys(shmem_ctx_t ctx, int pe, uint32_t seg, sshmem_mkey_t *mkeys)
{
opal_buffer_t *msg;
uint8_t cmd;
int i;
int rc;

if (OSHMEM_SUCCESS == MCA_SPML_CALL(oob_get_mkeys(pe, seg, mkeys))) {
if (OSHMEM_SUCCESS == MCA_SPML_CALL(oob_get_mkeys(ctx, pe, seg, mkeys))) {
for (i = 0; i < memheap_map->num_transports; i++) {
MEMHEAP_VERBOSE(5,
"MKEY CALCULATED BY LOCAL SPML: pe: %d tr_id: %d %s",
Expand All @@ -478,6 +479,7 @@ static int memheap_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
memheap_oob.mkeys = mkeys;
memheap_oob.segno = seg;
memheap_oob.mkeys_rcvd = 0;
memheap_oob.ctx = ctx;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can it be overwritten if several keys are requested in a row?


msg = OBJ_NEW(opal_buffer_t);
if (!msg) {
Expand Down Expand Up @@ -645,7 +647,7 @@ void mca_memheap_modex_recv_all(void)
}
memheap_oob.mkeys = s->mkeys_cache[i];
memheap_oob.segno = j;
unpack_remote_mkeys(msg, i);
unpack_remote_mkeys(oshmem_ctx_default, msg, i);
}
}

Expand Down Expand Up @@ -674,7 +676,8 @@ void mca_memheap_modex_recv_all(void)
}
}

sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(shmem_ctx_t ctx,
map_segment_t *s,
int pe,
void* va,
int btl_id,
Expand All @@ -692,7 +695,7 @@ sshmem_mkey_t * mca_memheap_base_get_cached_mkey_slow(map_segment_t *s,
if (!s->mkeys_cache[pe])
return NULL ;

rc = memheap_oob_get_mkeys(pe,
rc = memheap_oob_get_mkeys(ctx, pe,
s - memheap_map->mem_segs,
s->mkeys_cache[pe]);
if (OSHMEM_SUCCESS != rc)
Expand Down
5 changes: 3 additions & 2 deletions oshmem/mca/spml/base/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ OSHMEM_DECLSPEC int mca_spml_base_test(void* addr,
void* value,
int datatype,
int *out_value);
OSHMEM_DECLSPEC int mca_spml_base_oob_get_mkeys(int pe,
OSHMEM_DECLSPEC int mca_spml_base_oob_get_mkeys(shmem_ctx_t ctx,
int pe,
uint32_t seg,
sshmem_mkey_t *mkeys);

OSHMEM_DECLSPEC void mca_spml_base_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t seg, int pe, int tr_id);
OSHMEM_DECLSPEC void mca_spml_base_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t seg, int pe, int tr_id);
OSHMEM_DECLSPEC void mca_spml_base_rmkey_free(sshmem_mkey_t *mkey);
OSHMEM_DECLSPEC void *mca_spml_base_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe);

Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/spml/base/spml_base.c
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ int mca_spml_base_wait_nb(void* handle)
return OSHMEM_SUCCESS;
}

int mca_spml_base_oob_get_mkeys(int pe, uint32_t segno, sshmem_mkey_t *mkeys)
int mca_spml_base_oob_get_mkeys(shmem_ctx_t ctx, int pe, uint32_t segno, sshmem_mkey_t *mkeys)
{
return OSHMEM_ERROR;
}

void mca_spml_base_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
void mca_spml_base_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
{
}

Expand Down
10 changes: 5 additions & 5 deletions oshmem/mca/spml/ikrit/spml_ikrit.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ int mca_spml_ikrit_put_simple(void* dst_addr,
void* src_addr,
int dst);

static void mca_spml_ikrit_cache_mkeys(sshmem_mkey_t *, uint32_t seg, int remote_pe, int tr_id);
static void mca_spml_ikrit_cache_mkeys(shmem_ctx_t ctx, sshmem_mkey_t *, uint32_t seg, int remote_pe, int tr_id);

static mxm_mem_key_t *mca_spml_ikrit_get_mkey_slow(int pe, void *va, int ptl_id, void **rva);

Expand Down Expand Up @@ -187,7 +187,7 @@ mca_spml_ikrit_t mca_spml_ikrit = {
mca_spml_ikrit_get_mkey_slow
};

static void mca_spml_ikrit_cache_mkeys(sshmem_mkey_t *mkey, uint32_t seg, int dst_pe, int tr_id)
static void mca_spml_ikrit_cache_mkeys(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t seg, int dst_pe, int tr_id)
{
mxm_peer_t *peer;

Expand Down Expand Up @@ -506,7 +506,7 @@ sshmem_mkey_t *mca_spml_ikrit_register(void* addr,
my_rank, i, addr, (unsigned long long)size,
mca_spml_base_mkey2str(&mkeys[i]));

mca_spml_ikrit_cache_mkeys(&mkeys[i], memheap_find_segnum(addr), my_rank, i);
mca_spml_ikrit_cache_mkeys(NULL, &mkeys[i], memheap_find_segnum(addr), my_rank, i);
}
*count = MXM_PTL_LAST;

Expand Down Expand Up @@ -550,7 +550,7 @@ int mca_spml_ikrit_deregister(sshmem_mkey_t *mkeys)

}

int mca_spml_ikrit_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
int mca_spml_ikrit_oob_get_mkeys(shmem_ctx_t ctx, int pe, uint32_t seg, sshmem_mkey_t *mkeys)
{
int ptl;

Expand All @@ -569,7 +569,7 @@ int mca_spml_ikrit_oob_get_mkeys(int pe, uint32_t seg, sshmem_mkey_t *mkeys)
mkeys[ptl].len = 0;
mkeys[ptl].va_base = mca_memheap_seg2base_va(seg);
mkeys[ptl].u.key = MAP_SEGMENT_SHM_INVALID;
mca_spml_ikrit_cache_mkeys(&mkeys[ptl], seg, pe, ptl);
mca_spml_ikrit_cache_mkeys(NULL, &mkeys[ptl], seg, pe, ptl);
return OSHMEM_SUCCESS;
}

Expand Down
2 changes: 1 addition & 1 deletion oshmem/mca/spml/ikrit/spml_ikrit.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ extern sshmem_mkey_t *mca_spml_ikrit_register(void* addr,
uint64_t shmid,
int *count);
extern int mca_spml_ikrit_deregister(sshmem_mkey_t *mkeys);
extern int mca_spml_ikrit_oob_get_mkeys(int pe,
extern int mca_spml_ikrit_oob_get_mkeys(shmem_ctx_t ctx, int pe,
uint32_t segno,
sshmem_mkey_t *mkeys);

Expand Down
4 changes: 2 additions & 2 deletions oshmem/mca/spml/spml.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ typedef int (*mca_spml_base_module_test_fn_t)(void* addr,
*
* @param mkey remote mkey
*/
typedef void (*mca_spml_base_module_mkey_unpack_fn_t)(sshmem_mkey_t *, uint32_t segno, int remote_pe, int tr_id);
typedef void (*mca_spml_base_module_mkey_unpack_fn_t)(shmem_ctx_t ctx, sshmem_mkey_t *, uint32_t segno, int remote_pe, int tr_id);

/**
* If possible, get a pointer to the remote memory described by the mkey
Expand Down Expand Up @@ -180,7 +180,7 @@ typedef int (*mca_spml_base_module_deregister_fn_t)(sshmem_mkey_t *mkeys);
*
* @return OSHMEM_SUCCSESS if keys are found
*/
typedef int (*mca_spml_base_module_oob_get_mkeys_fn_t)(int pe,
typedef int (*mca_spml_base_module_oob_get_mkeys_fn_t)(shmem_ctx_t ctx, int pe,
uint32_t seg,
sshmem_mkey_t *mkeys);

Expand Down
33 changes: 13 additions & 20 deletions oshmem/mca/spml/ucx/spml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
#endif

static
spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva);
spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(shmem_ctx_t ctx, int pe, void *va, void **rva);

mca_spml_ucx_t mca_spml_ucx = {
.super = {
Expand Down Expand Up @@ -309,11 +309,11 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs)


static
spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva)
spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(shmem_ctx_t ctx, int pe, void *va, void **rva)
{
sshmem_mkey_t *r_mkey;

r_mkey = mca_memheap_base_get_cached_mkey(pe, va, 0, rva);
r_mkey = mca_memheap_base_get_cached_mkey(ctx, pe, va, 0, rva);
if (OPAL_UNLIKELY(!r_mkey)) {
SPML_UCX_ERROR("pe=%d: %p is not address of symmetric variable",
pe, va);
Expand Down Expand Up @@ -351,31 +351,24 @@ void *mca_spml_ucx_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe)
#endif
}

static void mca_spml_ucx_cache_mkey(mca_spml_ucx_ctx_t *ucx_ctx, sshmem_mkey_t *mkey, uint32_t segno, int dst_pe)
{
ucp_peer_t *peer;

peer = &(ucx_ctx->ucp_peers[dst_pe]);
mkey_segment_init(&peer->mkeys[segno].super, mkey, segno);
}

void mca_spml_ucx_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
void mca_spml_ucx_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
{
spml_ucx_mkey_t *ucx_mkey;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
ucs_status_t err;

ucx_mkey = &mca_spml_ucx_ctx_default.ucp_peers[pe].mkeys[segno].key;
ucx_mkey = &ucx_ctx->ucp_peers[pe].mkeys[segno].key;

err = ucp_ep_rkey_unpack(mca_spml_ucx_ctx_default.ucp_peers[pe].ucp_conn,
mkey->u.data,
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn,
mkey->u.data,
&ucx_mkey->rkey);
if (UCS_OK != err) {
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
goto error_fatal;
}

mkey->spml_context = ucx_mkey;
mca_spml_ucx_cache_mkey(&mca_spml_ucx_ctx_default, mkey, segno, pe);
mca_spml_ucx_cache_mkey(ucx_ctx, mkey, segno, pe);
return;

error_fatal:
Expand Down Expand Up @@ -636,7 +629,7 @@ int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_add
ucs_status_t status;
#endif

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, src, src_addr, &rva, &mca_spml_ucx);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx);
#if HAVE_DECL_UCP_GET_NB
request = ucp_get_nb(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size,
(uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb);
Expand All @@ -655,7 +648,7 @@ int mca_spml_ucx_get_nb(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_
spml_ucx_mkey_t *ucx_mkey;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, src, src_addr, &rva, &mca_spml_ucx);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx);
status = ucp_get_nbi(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size,
(uint64_t)rva, ucx_mkey->rkey);

Expand All @@ -673,7 +666,7 @@ int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_add
ucs_status_t status;
#endif

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, dst, dst_addr, &rva, &mca_spml_ucx);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx);
#if HAVE_DECL_UCP_PUT_NB
request = ucp_put_nb(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size,
(uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb);
Expand All @@ -692,7 +685,7 @@ int mca_spml_ucx_put_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_
spml_ucx_mkey_t *ucx_mkey;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;

ucx_mkey = mca_spml_ucx_get_mkey(ucx_ctx, dst, dst_addr, &rva, &mca_spml_ucx);
ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx);
status = ucp_put_nbi(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size,
(uint64_t)rva, ucx_mkey->rkey);

Expand Down
Loading