1616
1717#include "opal/runtime/opal.h"
1818#include "opal/mca/pmix/pmix.h"
19+ #include "ompi/attribute/attribute.h"
1920#include "ompi/message/message.h"
2021#include "ompi/mca/pml/base/pml_base_bsend.h"
2122#include "pml_ucx_request.h"
@@ -184,9 +185,9 @@ int mca_pml_ucx_close(void)
184185int mca_pml_ucx_init (void )
185186{
186187 ucp_worker_params_t params ;
187- ucs_status_t status ;
188188 ucp_worker_attr_t attr ;
189- int rc ;
189+ ucs_status_t status ;
190+ int i , rc ;
190191
191192 PML_UCX_VERBOSE (1 , "mca_pml_ucx_init" );
192193
@@ -203,30 +204,34 @@ int mca_pml_ucx_init(void)
203204 & ompi_pml_ucx .ucp_worker );
204205 if (UCS_OK != status ) {
205206 PML_UCX_ERROR ("Failed to create UCP worker" );
206- return OMPI_ERROR ;
207+ rc = OMPI_ERROR ;
208+ goto err ;
207209 }
208210
209211 attr .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE ;
210212 status = ucp_worker_query (ompi_pml_ucx .ucp_worker , & attr );
211213 if (UCS_OK != status ) {
212- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
213- ompi_pml_ucx .ucp_worker = NULL ;
214214 PML_UCX_ERROR ("Failed to query UCP worker thread level" );
215- return OMPI_ERROR ;
215+ rc = OMPI_ERROR ;
216+ goto err_destroy_worker ;
216217 }
217218
218- if (ompi_mpi_thread_multiple && attr .thread_mode != UCS_THREAD_MODE_MULTI ) {
219+ if (ompi_mpi_thread_multiple && ( attr .thread_mode != UCS_THREAD_MODE_MULTI ) ) {
219220 /* UCX does not support multithreading, disqualify current PML for now */
220221 /* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */
221- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
222- ompi_pml_ucx .ucp_worker = NULL ;
223222 PML_UCX_ERROR ("UCP worker does not support MPI_THREAD_MULTIPLE" );
224- return OMPI_ERROR ;
223+ rc = OMPI_ERR_NOT_SUPPORTED ;
224+ goto err_destroy_worker ;
225225 }
226226
227227 rc = mca_pml_ucx_send_worker_address ();
228228 if (rc < 0 ) {
229- return rc ;
229+ goto err_destroy_worker ;
230+ }
231+
232+ ompi_pml_ucx .datatype_attr_keyval = MPI_KEYVAL_INVALID ;
233+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
234+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
230235 }
231236
232237 /* Initialize the free lists */
@@ -242,15 +247,34 @@ int mca_pml_ucx_init(void)
242247 PML_UCX_VERBOSE (2 , "created ucp context %p, worker %p" ,
243248 (void * )ompi_pml_ucx .ucp_context ,
244249 (void * )ompi_pml_ucx .ucp_worker );
245- return OMPI_SUCCESS ;
250+ return rc ;
251+
252+ err_destroy_worker :
253+ ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
254+ ompi_pml_ucx .ucp_worker = NULL ;
255+ err :
256+ return OMPI_ERROR ;
246257}
247258
248259int mca_pml_ucx_cleanup (void )
249260{
261+ int i ;
262+
250263 PML_UCX_VERBOSE (1 , "mca_pml_ucx_cleanup" );
251264
252265 opal_progress_unregister (mca_pml_ucx_progress );
253266
267+ if (ompi_pml_ucx .datatype_attr_keyval != MPI_KEYVAL_INVALID ) {
268+ ompi_attr_free_keyval (TYPE_ATTR , & ompi_pml_ucx .datatype_attr_keyval , false);
269+ }
270+
271+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
272+ if (ompi_pml_ucx .predefined_types [i ] != PML_UCX_DATATYPE_INVALID ) {
273+ ucp_dt_destroy (ompi_pml_ucx .predefined_types [i ]);
274+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
275+ }
276+ }
277+
254278 ompi_pml_ucx .completed_send_req .req_state = OMPI_REQUEST_INVALID ;
255279 OMPI_REQUEST_FINI (& ompi_pml_ucx .completed_send_req );
256280 OBJ_DESTRUCT (& ompi_pml_ucx .completed_send_req );
@@ -448,6 +472,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
448472
449473int mca_pml_ucx_enable (bool enable )
450474{
475+ ompi_attribute_fn_ptr_union_t copy_fn ;
476+ ompi_attribute_fn_ptr_union_t del_fn ;
477+ int ret ;
478+
479+ /* Create a key for adding custom attributes to datatypes */
480+ copy_fn .attr_datatype_copy_fn =
481+ (MPI_Type_internal_copy_attr_function * )MPI_TYPE_NULL_COPY_FN ;
482+ del_fn .attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn ;
483+ ret = ompi_attr_create_keyval (TYPE_ATTR , copy_fn , del_fn ,
484+ & ompi_pml_ucx .datatype_attr_keyval , NULL , 0 ,
485+ NULL );
486+ if (ret != OMPI_SUCCESS ) {
487+ PML_UCX_ERROR ("Failed to create keyval for UCX datatypes: %d" , ret );
488+ return ret ;
489+ }
490+
451491 PML_UCX_FREELIST_INIT (& ompi_pml_ucx .persistent_reqs ,
452492 mca_pml_ucx_persistent_request_t ,
453493 128 , -1 , 128 );
0 commit comments