Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions config/ompi_check_ucx.m4
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
UCP_ATOMIC_FETCH_OP_FXOR,
UCP_PARAM_FIELD_ESTIMATED_NUM_PPN,
UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK,
UCP_OP_ATTR_FLAG_MULTI_SEND],
UCP_OP_ATTR_FLAG_MULTI_SEND,
UCP_MEM_MAP_SYMMETRIC_RKEY],
[], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS],
Expand All @@ -123,7 +124,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([ucp_tag_send_nbx,
ucp_tag_send_sync_nbx,
ucp_tag_recv_nbx],
ucp_tag_recv_nbx,
ucp_rkey_compare],
[], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_TYPES([ucp_request_param_t],
Expand Down
183 changes: 180 additions & 3 deletions oshmem/mca/spml/ucx/spml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "opal/datatype/opal_convertor.h"
#include "opal/mca/common/ucx/common_ucx.h"
#include "opal/util/opal_environ.h"
#include "opal/util/minmax.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/pml/pml.h"

Expand Down Expand Up @@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
};
#endif

unsigned
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
{
#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
if (spml_ucx->symmetric_rkey_max_count > 0) {
return UCP_MEM_MAP_SYMMETRIC_RKEY;
}
#endif

return 0;
}

void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
{
store->array = NULL;
store->count = 0;
store->size = 0;
}

void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
{
int i;

for (i = 0; i < store->count; i++) {
if (store->array[i].refcnt != 0) {
SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
i, store->count, store->array[i].refcnt);
}

ucp_rkey_destroy(store->array[i].rkey);
}

free(store->array);
}

/**
* Find position in sorted array for existing or future entry
*
* @param[in] store Store of the entries
* @param[in] worker Common worker for rkeys used
* @param[in] rkey Remote key to search for
* @param[out] index Index of entry
*
* @return
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
* should be inserted to keep array sorted
* OSHMEM_SUCCESS : index contains the position of the element
* Other error : index is not valid
*/
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
const ucp_worker_h worker,
const ucp_rkey_h rkey,
int *index)
{
#if HAVE_DECL_UCP_RKEY_COMPARE
ucp_rkey_compare_params_t params;
int i, result, m, end;
ucs_status_t status;

for (i = 0, end = store->count; i < end;) {
m = (i + end) / 2;

params.field_mask = 0;
status = ucp_rkey_compare(worker, store->array[m].rkey,
rkey, &params, &result);
if (status != UCS_OK) {
return OSHMEM_ERROR;
} else if (result == 0) {
*index = m;
return OSHMEM_SUCCESS;
} else if (result > 0) {
end = m;
} else {
i = m + 1;
}
}

*index = i;
return OSHMEM_ERR_NOT_FOUND;
#else
return OSHMEM_ERROR;
#endif
}

static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
int i, ucp_rkey_h rkey)
{
int size;
mca_spml_ucx_rkey_t *tmp;

if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
return;
}

if (store->count >= store->size) {
size = opal_min(opal_max(store->size, 8) * 2,
mca_spml_ucx.symmetric_rkey_max_count);
tmp = realloc(store->array, size * sizeof(*store->array));
if (tmp == NULL) {
return;
}

store->array = tmp;
store->size = size;
}

memmove(&store->array[i + 1], &store->array[i],
(store->count - i) * sizeof(*store->array));
store->array[i].rkey = rkey;
store->array[i].refcnt = 1;
store->count++;
return;
}

/* Takes ownership of input ucp remote key */
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
ucp_worker_h worker,
ucp_rkey_h rkey)
{
int ret, i;

if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
return rkey;
}

ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
if (ret == OSHMEM_SUCCESS) {
ucp_rkey_destroy(rkey);
store->array[i].refcnt++;
return store->array[i].rkey;
}

if (ret == OSHMEM_ERR_NOT_FOUND) {
mca_spml_ucx_rkey_store_insert(store, i, rkey);
}

return rkey;
}

static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
ucp_worker_h worker,
ucp_rkey_h rkey)
{
mca_spml_ucx_rkey_t *entry;
int ret, i;

ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
if (ret != OSHMEM_SUCCESS) {
goto out;
}

entry = &store->array[i];
assert(entry->rkey == rkey);
if (--entry->refcnt > 0) {
return;
}

memmove(&store->array[i], &store->array[i + 1],
(store->count - (i + 1)) * sizeof(*store->array));
store->count--;

out:
ucp_rkey_destroy(rkey);
}

int mca_spml_ucx_enable(bool enable)
{
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
Expand Down Expand Up @@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
{
int rc;
ucs_status_t err;
ucp_rkey_h rkey;

rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey);
if (OSHMEM_SUCCESS != rc) {
Expand All @@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
}

if (mkey->u.data) {
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey));
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey);
if (UCS_OK != err) {
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
return OSHMEM_ERROR;
}

if (!oshmem_proc_on_local_node(pe)) {
rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey);
}

(*ucx_mkey)->rkey = rkey;

rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe);
if (OSHMEM_SUCCESS != rc) {
SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed");
Expand All @@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
ucp_peer_t *ucp_peer;
int rc;
ucp_peer = &(ucx_ctx->ucp_peers[pe]);
ucp_rkey_destroy(ucx_mkey->rkey);
mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey);
ucx_mkey->rkey = NULL;
rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno);
if(OSHMEM_SUCCESS != rc){
Expand Down Expand Up @@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
mem_map_params.address = addr;
mem_map_params.length = size;
mem_map_params.flags = flags;
mem_map_params.flags = flags |
mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx);

status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h);
if (UCS_OK != status) {
Expand Down Expand Up @@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
}
}

mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);

*ucx_ctx_p = ucx_ctx;

return OSHMEM_SUCCESS;
Expand Down
41 changes: 30 additions & 11 deletions oshmem/mca/spml/ucx/spml_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,31 @@ struct ucp_peer {
size_t mkeys_cnt;
};
typedef struct ucp_peer ucp_peer_t;


/* An rkey_store entry */
typedef struct mca_spml_ucx_rkey {
ucp_rkey_h rkey;
int refcnt;
} mca_spml_ucx_rkey_t;

typedef struct mca_spml_ucx_rkey_store {
mca_spml_ucx_rkey_t *array;
int size;
int count;
} mca_spml_ucx_rkey_store_t;

struct mca_spml_ucx_ctx {
ucp_worker_h *ucp_worker;
ucp_peer_t *ucp_peers;
long options;
opal_bitmap_t put_op_bitmap;
unsigned long nb_progress_cnt;
unsigned int ucp_workers;
int *put_proc_indexes;
unsigned put_proc_count;
bool synchronized_quiet;
int strong_sync;
ucp_worker_h *ucp_worker;
ucp_peer_t *ucp_peers;
long options;
opal_bitmap_t put_op_bitmap;
unsigned long nb_progress_cnt;
unsigned int ucp_workers;
int *put_proc_indexes;
unsigned put_proc_count;
bool synchronized_quiet;
int strong_sync;
mca_spml_ucx_rkey_store_t rkey_store;
};
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;

Expand Down Expand Up @@ -128,6 +141,7 @@ struct mca_spml_ucx {
unsigned long nb_ucp_worker_progress;
unsigned int ucp_workers;
unsigned int ucp_worker_cnt;
int symmetric_rkey_max_count;
};
typedef struct mca_spml_ucx mca_spml_ucx_t;

Expand Down Expand Up @@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void
extern int mca_spml_ucx_team_reduce(shmem_team_t team, void
*dest, const void *source, size_t nreduce, int operation, int datatype);

extern unsigned
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);

extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);

static inline int
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)
Expand Down
Loading