@@ -60,7 +60,7 @@ static inline void ompi_osc_ucx_handle_incoming_post(ompi_osc_ucx_module_t *modu
 
 int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
-    int ret;
+    int ret = OMPI_SUCCESS;
 
     if (module->epoch_type.access != NONE_EPOCH &&
         module->epoch_type.access != FENCE_EPOCH) {
@@ -74,16 +74,12 @@ int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
     }
 
     if (!(assert & MPI_MODE_NOPRECEDE)) {
-        ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
+        ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0 /*ignore*/);
         if (ret != OMPI_SUCCESS) {
             return ret;
         }
     }
 
-    module->global_ops_num = 0;
-    memset(module->per_target_ops_nums, 0,
-           sizeof(int) * ompi_comm_size(module->comm));
-
     return module->comm->c_coll->coll_barrier(module->comm,
                                               module->comm->c_coll->coll_barrier_module);
 }
@@ -147,7 +143,7 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
 
             ompi_osc_ucx_handle_incoming_post(module, &(module->state.post_state[i]), ranks_in_win_grp, size);
         }
-        ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
+        opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
     }
 
     module->post_count = 0;
@@ -163,7 +159,6 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
 
 int ompi_osc_ucx_complete(struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
-    ucs_status_t status;
     int i, size;
     int ret = OMPI_SUCCESS;
 
@@ -173,29 +168,26 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
 
     module->epoch_type.access = NONE_EPOCH;
 
-    ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
+    ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0 /*ignore*/);
     if (ret != OMPI_SUCCESS) {
         return ret;
     }
-    module->global_ops_num = 0;
-    memset(module->per_target_ops_nums, 0,
-           sizeof(int) * ompi_comm_size(module->comm));
 
     size = ompi_group_size(module->start_group);
     for (i = 0; i < size; i++) {
-        uint64_t remote_addr = (module->state_info_array)[module->start_grp_ranks[i]].addr + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; /* write to state.complete_count on remote side */
-        ucp_rkey_h rkey = (module->state_info_array)[module->start_grp_ranks[i]].rkey;
-        ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, module->start_grp_ranks[i]);
-
-        status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, 1,
-                                 sizeof(uint64_t), remote_addr, rkey);
-        if (status != UCS_OK) {
-            OSC_UCX_VERBOSE(1, "ucp_atomic_post failed: %d", status);
+        uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side
+
+        ret = opal_common_ucx_wpmem_post(module->mem, UCP_ATOMIC_POST_OP_ADD,
+                                         1, module->start_grp_ranks[i], sizeof(uint64_t),
+                                         remote_addr);
+        if (ret != OMPI_SUCCESS) {
+            OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret);
         }
 
-        ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker);
-        if (OMPI_SUCCESS != ret) {
-            OSC_UCX_VERBOSE(1, "opal_common_ucx_ep_flush failed: %d", ret);
+        ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP,
+                                          module->start_grp_ranks[i]);
+        if (ret != OMPI_SUCCESS) {
+            return ret;
         }
     }
 
@@ -243,25 +235,29 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t
         }
 
         for (i = 0; i < size; i++) {
-            uint64_t remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_INDEX_OFFSET; /* write to state.post_index on remote side */
-            ucp_rkey_h rkey = (module->state_info_array)[ranks_in_win_grp[i]].rkey;
-            ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, ranks_in_win_grp[i]);
+            uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side
             uint64_t curr_idx = 0, result = 0;
 
             /* do fop first to get an post index */
-            opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1,
-                                         &result, sizeof(result),
-                                         remote_addr, rkey, mca_osc_ucx_component.ucp_worker);
+            ret = opal_common_ucx_wpmem_fetch(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
+                                              1, ranks_in_win_grp[i], &result,
+                                              sizeof(result), remote_addr);
+            if (ret != OMPI_SUCCESS) {
+                return OMPI_ERROR;
+            }
 
             curr_idx = result & (OMPI_OSC_UCX_POST_PEER_MAX - 1);
 
-            remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
+            remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
 
             /* do cas to send post message */
             do {
-                opal_common_ucx_atomic_cswap(ep, 0, (uint64_t)myrank + 1, &result,
-                                             sizeof(result), remote_addr, rkey,
-                                             mca_osc_ucx_component.ucp_worker);
+                ret = opal_common_ucx_wpmem_cmpswp(module->mem, 0, result,
+                                                   myrank + 1, &result, sizeof(result),
+                                                   remote_addr);
+                if (ret != OMPI_SUCCESS) {
+                    return OMPI_ERROR;
+                }
 
                 if (result == 0)
                     break;
@@ -302,7 +298,7 @@ int ompi_osc_ucx_wait(struct ompi_win_t *win) {
 
     while (module->state.complete_count != (uint64_t)size) {
         /* not sure if this is required */
-        ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
+        opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
     }
 
     module->state.complete_count = 0;
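
For context, a minimal sketch (not part of the patch) of the flush-then-barrier pattern the new fence path follows, built only from calls that appear in the hunks above (opal_common_ucx_wpmem_flush with OPAL_COMMON_UCX_SCOPE_WORKER and the communicator's barrier); the helper name is hypothetical:

/* hypothetical helper, assuming the module layout used in the patch */
static int osc_ucx_flush_then_barrier(ompi_osc_ucx_module_t *module)
{
    /* worker-scope flush: completes all outstanding RMA issued through this
     * window's wpmem handle; the target argument is ignored for this scope */
    int ret = opal_common_ucx_wpmem_flush(module->mem,
                                          OPAL_COMMON_UCX_SCOPE_WORKER,
                                          0 /*ignore*/);
    if (ret != OMPI_SUCCESS) {
        return ret;
    }

    /* completion tracking now lives inside the wpool machinery, so the old
     * global_ops_num / per_target_ops_nums counters are no longer reset here */
    return module->comm->c_coll->coll_barrier(module->comm,
                                              module->comm->c_coll->coll_barrier_module);
}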