@@ -207,10 +207,71 @@ static int progress_callback(void) {
207207 return 0 ;
208208}
209209
210+ static int ucp_context_init (bool enable_mt , int proc_world_size ) {
211+ int ret = OMPI_SUCCESS ;
212+ ucs_status_t status ;
213+ ucp_config_t * config = NULL ;
214+ ucp_params_t context_params ;
215+
216+ status = ucp_config_read ("MPI" , NULL , & config );
217+ if (UCS_OK != status ) {
218+ OSC_UCX_VERBOSE (1 , "ucp_config_read failed: %d" , status );
219+ return OMPI_ERROR ;
220+ }
221+
222+ /* initialize UCP context */
223+ memset (& context_params , 0 , sizeof (context_params ));
224+ context_params .field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED
225+ | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT
226+ | UCP_PARAM_FIELD_REQUEST_SIZE ;
227+ context_params .features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64 ;
228+ context_params .mt_workers_shared = (enable_mt ? 1 : 0 );
229+ context_params .estimated_num_eps = proc_world_size ;
230+ context_params .request_init = opal_common_ucx_req_init ;
231+ context_params .request_size = sizeof (opal_common_ucx_request_t );
232+
233+ #if HAVE_DECL_UCP_PARAM_FIELD_ESTIMATED_NUM_PPN
234+ context_params .estimated_num_ppn = opal_process_info .num_local_peers + 1 ;
235+ context_params .field_mask |= UCP_PARAM_FIELD_ESTIMATED_NUM_PPN ;
236+ #endif
237+
238+ status = ucp_init (& context_params , config , & mca_osc_ucx_component .wpool -> ucp_ctx );
239+ if (UCS_OK != status ) {
240+ OSC_UCX_VERBOSE (1 , "ucp_init failed: %d" , status );
241+ ret = OMPI_ERROR ;
242+ }
243+ ucp_config_release (config );
244+
245+ return ret ;
246+ }
247+
210248static int component_init (bool enable_progress_threads , bool enable_mpi_threads ) {
249+ opal_common_ucx_support_level_t support_level ;
250+ int ret = OMPI_SUCCESS ;
251+
211252 mca_osc_ucx_component .enable_mpi_threads = enable_mpi_threads ;
212253 mca_osc_ucx_component .wpool = opal_common_ucx_wpool_allocate ();
213254 opal_common_ucx_mca_register ();
255+
256+ ret = ucp_context_init (enable_mpi_threads , ompi_proc_world_size ());
257+ if (OMPI_ERROR == ret ) {
258+ return OMPI_ERR_NOT_AVAILABLE ;
259+ }
260+
261+ support_level = opal_common_ucx_support_level (mca_osc_ucx_component .wpool -> ucp_ctx );
262+ if (OPAL_COMMON_UCX_SUPPORT_NONE == support_level ) {
263+ ucp_cleanup (mca_osc_ucx_component .wpool -> ucp_ctx );
264+ mca_osc_ucx_component .wpool -> ucp_ctx = NULL ;
265+ return OMPI_ERR_NOT_AVAILABLE ;
266+ }
267+
268+ /*
269+ * Retain priority if we have supported devices and transports.
270+ * Lower priority if we have supported transports, but not supported devices.
271+ */
272+ mca_osc_ucx_component .priority = (support_level == OPAL_COMMON_UCX_SUPPORT_DEVICE ) ?
273+ mca_osc_ucx_component .priority : 19 ;
274+ OSC_UCX_VERBOSE (2 , "returning priority %d" , mca_osc_ucx_component .priority );
214275 return OMPI_SUCCESS ;
215276}
216277
@@ -395,9 +456,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
395456 goto select_unlock ;
396457 }
397458
398- ret = opal_common_ucx_wpool_init (mca_osc_ucx_component .wpool ,
399- ompi_proc_world_size (),
400- mca_osc_ucx_component .enable_mpi_threads );
459+ ret = opal_common_ucx_wpool_init (mca_osc_ucx_component .wpool );
401460 if (OMPI_SUCCESS != ret ) {
402461 OSC_UCX_VERBOSE (1 , "opal_common_ucx_wpool_init failed: %d" , ret );
403462 goto select_unlock ;
0 commit comments