2222#include "opal/datatype/opal_convertor.h"
2323#include "opal/mca/common/ucx/common_ucx.h"
2424#include "opal/util/opal_environ.h"
25+ #include "opal/util/minmax.h"
2526#include "ompi/datatype/ompi_datatype.h"
2627#include "ompi/mca/pml/pml.h"
2728
@@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
126127};
127128#endif
128129
130+ unsigned
131+ mca_spml_ucx_mem_map_flags_symmetric_rkey (struct mca_spml_ucx * spml_ucx )
132+ {
133+ #if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
134+ if (spml_ucx -> symmetric_rkey_max_count > 0 ) {
135+ return UCP_MEM_MAP_SYMMETRIC_RKEY ;
136+ }
137+ #endif
138+
139+ return 0 ;
140+ }
141+
142+ void mca_spml_ucx_rkey_store_init (mca_spml_ucx_rkey_store_t * store )
143+ {
144+ store -> array = NULL ;
145+ store -> count = 0 ;
146+ store -> size = 0 ;
147+ }
148+
149+ void mca_spml_ucx_rkey_store_cleanup (mca_spml_ucx_rkey_store_t * store )
150+ {
151+ int i ;
152+
153+ for (i = 0 ; i < store -> count ; i ++ ) {
154+ if (store -> array [i ].refcnt != 0 ) {
155+ SPML_UCX_ERROR ("rkey store destroy: %d/%d has refcnt %d > 0" ,
156+ i , store -> count , store -> array [i ].refcnt );
157+ }
158+
159+ ucp_rkey_destroy (store -> array [i ].rkey );
160+ }
161+
162+ free (store -> array );
163+ }
164+
165+ /**
166+ * Find position in sorted array for existing or future entry
167+ *
168+ * @param[in] store Store of the entries
169+ * @param[in] worker Common worker for rkeys used
170+ * @param[in] rkey Remote key to search for
171+ * @param[out] index Index of entry
172+ *
173+ * @return
174+ * OSHMEM_ERR_NOT_FOUND: index contains the position where future element
175+ * should be inserted to keep array sorted
176+ * OSHMEM_SUCCESS : index contains the position of the element
177+ * Other error : index is not valid
178+ */
179+ static int mca_spml_ucx_rkey_store_find (const mca_spml_ucx_rkey_store_t * store ,
180+ const ucp_worker_h worker ,
181+ const ucp_rkey_h rkey ,
182+ int * index )
183+ {
184+ #if HAVE_DECL_UCP_RKEY_COMPARE
185+ ucp_rkey_compare_params_t params ;
186+ int i , result , m , end ;
187+ ucs_status_t status ;
188+
189+ for (i = 0 , end = store -> count ; i < end ;) {
190+ m = (i + end ) / 2 ;
191+
192+ params .field_mask = 0 ;
193+ status = ucp_rkey_compare (worker , store -> array [m ].rkey ,
194+ rkey , & params , & result );
195+ if (status != UCS_OK ) {
196+ return OSHMEM_ERROR ;
197+ } else if (result == 0 ) {
198+ * index = m ;
199+ return OSHMEM_SUCCESS ;
200+ } else if (result > 0 ) {
201+ end = m ;
202+ } else {
203+ i = m + 1 ;
204+ }
205+ }
206+
207+ * index = i ;
208+ return OSHMEM_ERR_NOT_FOUND ;
209+ #else
210+ return OSHMEM_ERROR ;
211+ #endif
212+ }
213+
214+ static void mca_spml_ucx_rkey_store_insert (mca_spml_ucx_rkey_store_t * store ,
215+ int i , ucp_rkey_h rkey )
216+ {
217+ int size ;
218+ mca_spml_ucx_rkey_t * tmp ;
219+
220+ if (store -> count >= mca_spml_ucx .symmetric_rkey_max_count ) {
221+ return ;
222+ }
223+
224+ if (store -> count >= store -> size ) {
225+ size = opal_min (opal_max (store -> size , 8 ) * 2 ,
226+ mca_spml_ucx .symmetric_rkey_max_count );
227+ tmp = realloc (store -> array , size * sizeof (* store -> array ));
228+ if (tmp == NULL ) {
229+ return ;
230+ }
231+
232+ store -> array = tmp ;
233+ store -> size = size ;
234+ }
235+
236+ memmove (& store -> array [i + 1 ], & store -> array [i ],
237+ (store -> count - i ) * sizeof (* store -> array ));
238+ store -> array [i ].rkey = rkey ;
239+ store -> array [i ].refcnt = 1 ;
240+ store -> count ++ ;
241+ return ;
242+ }
243+
244+ /* Takes ownership of input ucp remote key */
245+ static ucp_rkey_h mca_spml_ucx_rkey_store_get (mca_spml_ucx_rkey_store_t * store ,
246+ ucp_worker_h worker ,
247+ ucp_rkey_h rkey )
248+ {
249+ int ret , i ;
250+
251+ if (mca_spml_ucx .symmetric_rkey_max_count == 0 ) {
252+ return rkey ;
253+ }
254+
255+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
256+ if (ret == OSHMEM_SUCCESS ) {
257+ ucp_rkey_destroy (rkey );
258+ store -> array [i ].refcnt ++ ;
259+ return store -> array [i ].rkey ;
260+ }
261+
262+ if (ret == OSHMEM_ERR_NOT_FOUND ) {
263+ mca_spml_ucx_rkey_store_insert (store , i , rkey );
264+ }
265+
266+ return rkey ;
267+ }
268+
269+ static void mca_spml_ucx_rkey_store_put (mca_spml_ucx_rkey_store_t * store ,
270+ ucp_worker_h worker ,
271+ ucp_rkey_h rkey )
272+ {
273+ mca_spml_ucx_rkey_t * entry ;
274+ int ret , i ;
275+
276+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
277+ if (ret != OSHMEM_SUCCESS ) {
278+ goto out ;
279+ }
280+
281+ entry = & store -> array [i ];
282+ assert (entry -> rkey == rkey );
283+ if (-- entry -> refcnt > 0 ) {
284+ return ;
285+ }
286+
287+ memmove (& store -> array [i ], & store -> array [i + 1 ],
288+ (store -> count - (i + 1 )) * sizeof (* store -> array ));
289+ store -> count -- ;
290+
291+ out :
292+ ucp_rkey_destroy (rkey );
293+ }
294+
129295int mca_spml_ucx_enable (bool enable )
130296{
131297 SPML_UCX_VERBOSE (50 , "*** ucx ENABLED ****" );
@@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
240406{
241407 int rc ;
242408 ucs_status_t err ;
409+ ucp_rkey_h rkey ;
243410
244411 rc = mca_spml_ucx_ctx_mkey_new (ucx_ctx , pe , segno , ucx_mkey );
245412 if (OSHMEM_SUCCESS != rc ) {
@@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
248415 }
249416
250417 if (mkey -> u .data ) {
251- err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & (( * ucx_mkey ) -> rkey ) );
418+ err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & rkey );
252419 if (UCS_OK != err ) {
253420 SPML_UCX_ERROR ("failed to unpack rkey: %s" , ucs_status_string (err ));
254421 return OSHMEM_ERROR ;
255422 }
423+
424+ if (!oshmem_proc_on_local_node (pe )) {
425+ rkey = mca_spml_ucx_rkey_store_get (& ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [0 ], rkey );
426+ }
427+
428+ (* ucx_mkey )-> rkey = rkey ;
429+
256430 rc = mca_spml_ucx_ctx_mkey_cache (ucx_ctx , mkey , segno , pe );
257431 if (OSHMEM_SUCCESS != rc ) {
258432 SPML_UCX_ERROR ("mca_spml_ucx_ctx_mkey_cache failed" );
@@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
267441 ucp_peer_t * ucp_peer ;
268442 int rc ;
269443 ucp_peer = & (ucx_ctx -> ucp_peers [pe ]);
270- ucp_rkey_destroy ( ucx_mkey -> rkey );
444+ mca_spml_ucx_rkey_store_put ( & ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [ 0 ], ucx_mkey -> rkey );
271445 ucx_mkey -> rkey = NULL ;
272446 rc = mca_spml_ucx_peer_mkey_cache_del (ucp_peer , segno );
273447 if (OSHMEM_SUCCESS != rc ){
@@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
725899 UCP_MEM_MAP_PARAM_FIELD_FLAGS ;
726900 mem_map_params .address = addr ;
727901 mem_map_params .length = size ;
728- mem_map_params .flags = flags ;
902+ mem_map_params .flags = flags |
903+ mca_spml_ucx_mem_map_flags_symmetric_rkey (& mca_spml_ucx );
729904
730905 status = ucp_mem_map (mca_spml_ucx .ucp_context , & mem_map_params , & mem_h );
731906 if (UCS_OK != status ) {
@@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
9171092 }
9181093 }
9191094
1095+ mca_spml_ucx_rkey_store_init (& ucx_ctx -> rkey_store );
1096+
9201097 * ucx_ctx_p = ucx_ctx ;
9211098
9221099 return OSHMEM_SUCCESS ;
0 commit comments