1616
1717#include  "opal/runtime/opal.h" 
1818#include  "opal/mca/pmix/pmix.h" 
19+ #include  "ompi/attribute/attribute.h" 
1920#include  "ompi/message/message.h" 
2021#include  "ompi/mca/pml/base/pml_base_bsend.h" 
2122#include  "pml_ucx_request.h" 
@@ -184,9 +185,9 @@ int mca_pml_ucx_close(void)
184185int  mca_pml_ucx_init (void )
185186{
186187    ucp_worker_params_t  params ;
187-     ucs_status_t  status ;
188188    ucp_worker_attr_t  attr ;
189-     int  rc ;
189+     ucs_status_t  status ;
190+     int  i , rc ;
190191
191192    PML_UCX_VERBOSE (1 , "mca_pml_ucx_init" );
192193
@@ -203,30 +204,34 @@ int mca_pml_ucx_init(void)
203204                               & ompi_pml_ucx .ucp_worker );
204205    if  (UCS_OK  !=  status ) {
205206        PML_UCX_ERROR ("Failed to create UCP worker" );
206-         return  OMPI_ERROR ;
207+         rc  =  OMPI_ERROR ;
208+         goto err ;
207209    }
208210
209211    attr .field_mask  =  UCP_WORKER_ATTR_FIELD_THREAD_MODE ;
210212    status  =  ucp_worker_query (ompi_pml_ucx .ucp_worker , & attr );
211213    if  (UCS_OK  !=  status ) {
212-         ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
213-         ompi_pml_ucx .ucp_worker  =  NULL ;
214214        PML_UCX_ERROR ("Failed to query UCP worker thread level" );
215-         return  OMPI_ERROR ;
215+         rc  =  OMPI_ERROR ;
216+         goto err_destroy_worker ;
216217    }
217218
218-     if  (ompi_mpi_thread_multiple  &&  attr .thread_mode  !=  UCS_THREAD_MODE_MULTI ) {
219+     if  (ompi_mpi_thread_multiple  &&  ( attr .thread_mode  !=  UCS_THREAD_MODE_MULTI ) ) {
219220        /* UCX does not support multithreading, disqualify current PML for now */ 
220221        /* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */ 
221-         ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
222-         ompi_pml_ucx .ucp_worker  =  NULL ;
223222        PML_UCX_ERROR ("UCP worker does not support MPI_THREAD_MULTIPLE" );
224-         return  OMPI_ERROR ;
223+         rc  =  OMPI_ERR_NOT_SUPPORTED ;
224+         goto err_destroy_worker ;
225225    }
226226
227227    rc  =  mca_pml_ucx_send_worker_address ();
228228    if  (rc  <  0 ) {
229-         return  rc ;
229+         goto err_destroy_worker ;
230+     }
231+ 
232+     ompi_pml_ucx .datatype_attr_keyval  =  MPI_KEYVAL_INVALID ;
233+     for  (i  =  0 ; i  <  OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
234+         ompi_pml_ucx .predefined_types [i ] =  PML_UCX_DATATYPE_INVALID ;
230235    }
231236
232237    /* Initialize the free lists */ 
@@ -243,14 +248,33 @@ int mca_pml_ucx_init(void)
243248                    (void  * )ompi_pml_ucx .ucp_context ,
244249                    (void  * )ompi_pml_ucx .ucp_worker );
245250    return  OMPI_SUCCESS ;
251+ 
252+ err_destroy_worker :
253+     ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
254+     ompi_pml_ucx .ucp_worker  =  NULL ;
255+ err :
256+     return  OMPI_ERROR ;
246257}
247258
248259int  mca_pml_ucx_cleanup (void )
249260{
261+     int  i ;
262+ 
250263    PML_UCX_VERBOSE (1 , "mca_pml_ucx_cleanup" );
251264
252265    opal_progress_unregister (mca_pml_ucx_progress );
253266
267+     if  (ompi_pml_ucx .datatype_attr_keyval  !=  MPI_KEYVAL_INVALID ) {
268+         ompi_attr_free_keyval (TYPE_ATTR , & ompi_pml_ucx .datatype_attr_keyval , false);
269+     }
270+ 
271+     for  (i  =  0 ; i  <  OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
272+         if  (ompi_pml_ucx .predefined_types [i ] !=  PML_UCX_DATATYPE_INVALID ) {
273+             ucp_dt_destroy (ompi_pml_ucx .predefined_types [i ]);
274+             ompi_pml_ucx .predefined_types [i ] =  PML_UCX_DATATYPE_INVALID ;
275+         }
276+     }
277+ 
254278    ompi_pml_ucx .completed_send_req .req_state  =  OMPI_REQUEST_INVALID ;
255279    OMPI_REQUEST_FINI (& ompi_pml_ucx .completed_send_req );
256280    OBJ_DESTRUCT (& ompi_pml_ucx .completed_send_req );
@@ -448,6 +472,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
448472
449473int  mca_pml_ucx_enable (bool  enable )
450474{
475+     ompi_attribute_fn_ptr_union_t  copy_fn ;
476+     ompi_attribute_fn_ptr_union_t  del_fn ;
477+     int  ret ;
478+ 
479+     /* Create a key for adding custom attributes to datatypes */ 
480+     copy_fn .attr_datatype_copy_fn   = 
481+                     (MPI_Type_internal_copy_attr_function * )MPI_TYPE_NULL_COPY_FN ;
482+     del_fn .attr_datatype_delete_fn  =  mca_pml_ucx_datatype_attr_del_fn ;
483+     ret  =  ompi_attr_create_keyval (TYPE_ATTR , copy_fn , del_fn ,
484+                                   & ompi_pml_ucx .datatype_attr_keyval , NULL , 0 ,
485+                                   NULL );
486+     if  (ret  !=  OMPI_SUCCESS ) {
487+         PML_UCX_ERROR ("Failed to create keyval for UCX datatypes: %d" , ret );
488+         return  ret ;
489+     }
490+ 
451491    PML_UCX_FREELIST_INIT (& ompi_pml_ucx .persistent_reqs ,
452492                          mca_pml_ucx_persistent_request_t ,
453493                          128 , -1 , 128 );
0 commit comments