@@ -18,7 +18,7 @@ categories :
1818=== END_MPI_T_CVAR_INFO_BLOCK ===
1919*/
2020
21- static void request_init_callback ( void * request ) ;
21+ static bool ucx_initialized = false ;
2222
2323static void request_init_callback (void * request )
2424{
@@ -28,8 +28,6 @@ static void request_init_callback(void *request)
2828
2929}
3030
31- static void flush_all (void );
32-
3331int MPIDI_UCX_init_worker (int vci )
3432{
3533 int mpi_errno = MPI_SUCCESS ;
@@ -143,68 +141,6 @@ static int initial_address_exchange(void)
143141 goto fn_exit ;
144142}
145143
146- int MPIDI_UCX_all_vcis_address_exchange (void )
147- {
148- int mpi_errno = MPI_SUCCESS ;
149-
150- int size = MPIR_Process .size ;
151- int rank = MPIR_Process .rank ;
152- int num_vcis = MPIDI_UCX_global .num_vcis ;
153-
154- /* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
155- size_t name_len = MPID_MAX_BC_SIZE ;
156-
157- int my_len = num_vcis * name_len ;
158- char * all_names = MPL_malloc (size * my_len , MPL_MEM_ADDRESS );
159- MPIR_Assert (all_names );
160-
161- char * my_names = all_names + rank * my_len ;
162-
163- /* put in my addrnames */
164- for (int i = 0 ; i < num_vcis ; i ++ ) {
165- char * vci_addrname = my_names + i * name_len ;
166- memcpy (vci_addrname , MPIDI_UCX_global .ctx [i ].if_address ,
167- MPIDI_UCX_global .ctx [i ].addrname_len );
168- }
169- /* Allgather */
170- MPIR_Comm * comm = MPIR_Process .comm_world ;
171- mpi_errno = MPIR_Allgather_allcomm_auto (MPI_IN_PLACE , 0 , MPIR_BYTE_INTERNAL ,
172- all_names , my_len , MPIR_BYTE_INTERNAL , comm ,
173- MPIR_ERR_NONE );
174- MPIR_ERR_CHECK (mpi_errno );
175-
176- /* insert the addresses */
177- ucp_ep_params_t ep_params ;
178- for (int vci_local = 0 ; vci_local < num_vcis ; vci_local ++ ) {
179- for (int r = 0 ; r < size ; r ++ ) {
180- MPIDI_UCX_addr_t * av = & MPIDI_UCX_AV (& MPIDIU_get_av (0 , r ));
181- for (int vci_remote = 0 ; vci_remote < num_vcis ; vci_remote ++ ) {
182- if (vci_local == 0 && vci_remote == 0 ) {
183- /* don't overwrite existing addr, or bad things will happen */
184- continue ;
185- }
186- int idx = r * num_vcis + vci_remote ;
187- ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
188- ep_params .address = (ucp_address_t * ) (all_names + idx * name_len );
189-
190- ucs_status_t ucx_status ;
191- ucx_status = ucp_ep_create (MPIDI_UCX_global .ctx [vci_local ].worker ,
192- & ep_params , & av -> dest [vci_local ][vci_remote ]);
193- MPIDI_UCX_CHK_STATUS (ucx_status );
194- }
195- }
196- }
197-
198- /* Flush all pending wireup operations or it may interfere with RMA flush_ops count. */
199- flush_all ();
200-
201- fn_exit :
202- MPL_free (all_names );
203- return mpi_errno ;
204- fn_fail :
205- goto fn_exit ;
206- }
207-
208144int MPIDI_UCX_init_local (int * tag_bits )
209145{
210146 int mpi_errno = MPI_SUCCESS ;
@@ -234,7 +170,7 @@ int MPIDI_UCX_init_local(int *tag_bits)
234170 UCP_PARAM_FIELD_REQUEST_SIZE |
235171 UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT ;
236172
237- if (MPIDI_UCX_global . num_vcis > 1 ) {
173+ if (MPICH_IS_THREADED ) {
238174 ucp_params .mt_workers_shared = 1 ;
239175 ucp_params .field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED ;
240176 }
@@ -277,6 +213,8 @@ int MPIDI_UCX_init_world(void)
277213 mpi_errno = initial_address_exchange ();
278214 MPIR_ERR_CHECK (mpi_errno );
279215
216+ ucx_initialized = true;
217+
280218 fn_exit :
281219 return mpi_errno ;
282220 fn_fail :
@@ -286,53 +224,23 @@ int MPIDI_UCX_init_world(void)
286224 goto fn_exit ;
287225}
288226
289- /* static functions for MPIDI_UCX_post_init */
290- static void flush_cb (void * request , ucs_status_t status )
291- {
292- }
293-
294- static void flush_all (void )
295- {
296- void * reqs [MPIDI_CH4_MAX_VCIS ];
297- for (int vci = 0 ; vci < MPIDI_UCX_global .num_vcis ; vci ++ ) {
298- reqs [vci ] = ucp_worker_flush_nb (MPIDI_UCX_global .ctx [vci ].worker , 0 , & flush_cb );
299- }
300- for (int vci = 0 ; vci < MPIDI_UCX_global .num_vcis ; vci ++ ) {
301- if (reqs [vci ] == NULL ) {
302- continue ;
303- } else if (UCS_PTR_IS_ERR (reqs [vci ])) {
304- continue ;
305- } else {
306- ucs_status_t status ;
307- do {
308- MPID_Progress_test (NULL );
309- status = ucp_request_check_status (reqs [vci ]);
310- } while (status == UCS_INPROGRESS );
311- ucp_request_release (reqs [vci ]);
312- }
313- }
314- }
315-
316227int MPIDI_UCX_post_init (void )
317228{
318229 int mpi_errno = MPI_SUCCESS ;
319230
320- MPIDI_global .is_initialized = 1 ;
321-
322231 return mpi_errno ;
323232}
324233
325234int MPIDI_UCX_mpi_finalize_hook (void )
326235{
327236 int mpi_errno = MPI_SUCCESS ;
328237
329- if (!MPIDI_global .is_initialized ) {
330- /* Nothing to do */
331- return mpi_errno ;
332- }
333-
334238 ucs_status_ptr_t ucp_request ;
335- ucs_status_ptr_t * pending ;
239+ ucs_status_ptr_t * pending = NULL ;
240+
241+ if (!ucx_initialized ) {
242+ goto fn_exit ;
243+ }
336244
337245 int n = MPIDI_UCX_global .num_vcis ;
338246 pending = MPL_malloc (sizeof (ucs_status_ptr_t ) * MPIR_Process .size * n * n , MPL_MEM_OTHER );
0 commit comments