diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index fc4c7a33f50..6a5a9901e98 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -30,7 +30,7 @@ int mca_atomic_ucx_cswap(void *target, spml_ucx_mkey_t *ucx_mkey; uint64_t rva; - ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); + ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva, mca_spml_self); if (NULL == cond) { switch (nlong) { case 4: diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c index a1b88c95deb..053b049bf00 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c @@ -29,8 +29,7 @@ int mca_atomic_ucx_fadd(void *target, spml_ucx_mkey_t *ucx_mkey; uint64_t rva; - ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); - + ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva, mca_spml_self); if (NULL == prev) { switch (nlong) { case 4: diff --git a/oshmem/mca/spml/ucx/Makefile.am b/oshmem/mca/spml/ucx/Makefile.am index 84d8a749250..8cbdb1d9318 100644 --- a/oshmem/mca/spml/ucx/Makefile.am +++ b/oshmem/mca/spml/ucx/Makefile.am @@ -34,7 +34,8 @@ mcacomponentdir = $(ompilibdir) mcacomponent_LTLIBRARIES = $(component_install) mca_spml_ucx_la_SOURCES = $(ucx_sources) mca_spml_ucx_la_LIBADD = $(top_builddir)/oshmem/liboshmem.la \ - $(spml_ucx_LIBS) + $(spml_ucx_LIBS) \ + $(top_builddir)/oshmem/mca/spml/libmca_spml.la mca_spml_ucx_la_LDFLAGS = -module -avoid-version $(spml_ucx_LDFLAGS) noinst_LTLIBRARIES = $(component_noinst) diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 42c455f9295..cccf9e4ebe3 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -43,6 +43,9 @@ #define SPML_UCX_PUT_DEBUG 0 #endif +static +spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva); + mca_spml_ucx_t mca_spml_ucx = { { /* Init mca_spml_base_module_t */ @@ -74,7 +77,9 @@ mca_spml_ucx_t mca_spml_ucx = { NULL, /* ucp_peers */ 0, /* using_mem_hooks */ 1, /* num_disconnect */ - 0 /* heap_reg_nb */ + 0, /* heap_reg_nb */ + 0, /* enabled */ + mca_spml_ucx_get_mkey_slow }; int mca_spml_ucx_enable(bool enable) @@ -330,6 +335,7 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) } +static spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva) { sshmem_mkey_t *r_mkey; @@ -550,7 +556,7 @@ int mca_spml_ucx_get(void *src_addr, size_t size, void *dst_addr, int src) ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva, &mca_spml_ucx); status = ucp_get(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -563,7 +569,7 @@ int mca_spml_ucx_get_nb(void *src_addr, size_t size, void *dst_addr, int src, vo ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva, &mca_spml_ucx); status = ucp_get_nbi(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -576,7 +582,7 @@ int mca_spml_ucx_put(void* dst_addr, size_t size, void* src_addr, int dst) ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva, &mca_spml_ucx); status = ucp_put(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -589,7 +595,7 @@ int mca_spml_ucx_put_nb(void* dst_addr, size_t size, void* src_addr, int dst, vo ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva, &mca_spml_ucx); status = ucp_put_nbi(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index b57850414bb..4aeed1481f3 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -58,6 +58,8 @@ struct ucp_peer { }; typedef struct ucp_peer ucp_peer_t; +typedef spml_ucx_mkey_t * (*mca_spml_ucx_get_mkey_slow_fn_t)(int pe, void *va, void **rva); + struct mca_spml_ucx { mca_spml_base_module_t super; ucp_context_h ucp_context; @@ -68,6 +70,8 @@ struct mca_spml_ucx { int priority; /* component priority */ bool enabled; + + mca_spml_ucx_get_mkey_slow_fn_t get_mkey_slow; }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -121,17 +125,16 @@ extern int mca_spml_ucx_quiet(void); extern int spml_ucx_progress(void); -spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva); - static inline spml_ucx_mkey_t * -mca_spml_ucx_get_mkey(int pe, void *va, void **rva) +mca_spml_ucx_get_mkey(int pe, void *va, void **rva, mca_spml_ucx_t* module) { spml_ucx_cached_mkey_t *mkey; - mkey = mca_spml_ucx.ucp_peers[pe].mkeys; + mkey = module->ucp_peers[pe].mkeys; mkey = (spml_ucx_cached_mkey_t *)map_segment_find_va(&mkey->super.super, sizeof(*mkey), va); if (OPAL_UNLIKELY(NULL == mkey)) { - return mca_spml_ucx_get_mkey_slow(pe, va, rva); + assert(module->get_mkey_slow); + return module->get_mkey_slow(pe, va, rva); } *rva = map_segment_va2rva(&mkey->super, va); return &mkey->key;