@@ -173,6 +173,96 @@ ucs_status_t uct_cuda_base_iface_event_fd_arm(uct_iface_h tl_iface,
173173 return UCS_OK ;
174174}
175175
176+ static void uct_cuda_base_stream_flushed_cb (uct_completion_t * self )
177+ {
178+ uct_cuda_flush_stream_desc_t * desc =
179+ ucs_container_of (self , uct_cuda_flush_stream_desc_t , comp );
180+
181+ if (-- desc -> flush_desc -> stream_counter == 0 ) {
182+ uct_invoke_completion (desc -> flush_desc -> comp , UCS_OK );
183+ ucs_free (desc -> flush_desc );
184+ }
185+ }
186+
187+ /* Flush is done by enqueueing flush events on all active streams and wait
188+ * for them to finish. On each flush event completion, we decrement a shared
189+ * counter and once it hits zero, flush is completed. */
190+ ucs_status_t
191+ uct_cuda_base_ep_flush (uct_ep_h tl_ep , unsigned flags , uct_completion_t * comp )
192+ {
193+ uct_base_ep_t * ep = ucs_derived_of (tl_ep , uct_base_ep_t );
194+ uct_cuda_iface_t * iface = ucs_derived_of (tl_ep -> iface , uct_cuda_iface_t );
195+ uct_cuda_flush_desc_t * flush_desc ;
196+ uct_cuda_flush_stream_desc_t * flush_stream_desc ;
197+ uct_cuda_queue_desc_t * q_desc ;
198+ uct_cuda_event_desc_t * event_desc ;
199+ ucs_queue_head_t * event_queue ;
200+ ucs_queue_iter_t iter ;
201+ unsigned stream_index ;
202+
203+ if (ucs_queue_is_empty (& iface -> active_queue )) {
204+ UCT_TL_EP_STAT_FLUSH (ep );
205+ return UCS_OK ;
206+ }
207+
208+ if (comp == NULL ) {
209+ goto out ;
210+ }
211+
212+ /* Allocate base flush descriptor */
213+ flush_desc = ucs_malloc (sizeof (* flush_desc ), "cuda_flush_desc" );
214+ if (flush_desc == NULL ) {
215+ return UCS_ERR_NO_MEMORY ;
216+ }
217+
218+ flush_desc -> comp = comp ;
219+ flush_desc -> stream_counter = 0 ;
220+
221+ /* For each active stream, init a flush event and enqueue it on the
222+ * stream */
223+ ucs_queue_for_each (q_desc , & iface -> active_queue , queue ) {
224+ flush_stream_desc = ucs_mpool_get (& iface -> flush_mpool );
225+ if (flush_stream_desc == NULL ) {
226+ goto error ;
227+ }
228+
229+ flush_stream_desc -> flush_desc = flush_desc ;
230+ flush_stream_desc -> comp .func = uct_cuda_base_stream_flushed_cb ;
231+ flush_stream_desc -> comp .count = 1 ;
232+ flush_stream_desc -> super .comp = & flush_stream_desc -> comp ;
233+ ucs_queue_push (& q_desc -> event_queue , & flush_stream_desc -> super .queue );
234+ flush_desc -> stream_counter ++ ;
235+ }
236+
237+ out :
238+ UCT_TL_EP_STAT_FLUSH_WAIT (ep );
239+ return UCS_INPROGRESS ;
240+
241+ error :
242+ /* Rollback enqueued items in case of error */
243+ for (iter = ucs_queue_iter_begin (& iface -> active_queue ), stream_index = 0 ;
244+ stream_index < flush_desc -> stream_counter ;
245+ iter = ucs_queue_iter_next (iter ), ++ stream_index ) {
246+ event_queue = & ucs_queue_iter_elem (q_desc , iter , queue )-> event_queue ;
247+ event_desc = ucs_queue_tail_elem_non_empty (event_queue ,
248+ uct_cuda_event_desc_t ,
249+ queue );
250+
251+ ucs_queue_remove (event_queue , & event_desc -> queue );
252+ ucs_mpool_put ((uct_cuda_flush_stream_desc_t * )event_desc );
253+ }
254+
255+ ucs_free (flush_desc );
256+ return UCS_ERR_NO_MEMORY ;
257+ }
258+
259+ static UCS_F_ALWAYS_INLINE int
260+ uct_cuda_base_event_is_flush (uct_cuda_event_desc_t * event )
261+ {
262+ return (event -> comp != NULL ) &&
263+ (event -> comp -> func == uct_cuda_base_stream_flushed_cb );
264+ }
265+
176266static UCS_F_ALWAYS_INLINE unsigned
177267uct_cuda_base_progress_event_queue (uct_cuda_iface_t * iface ,
178268 ucs_queue_head_t * queue_head ,
@@ -183,13 +273,18 @@ uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface,
183273
184274 ucs_queue_for_each_extract (cuda_event , queue_head , queue ,
185275 (count < max_events ) &&
186- (cuEventQuery (cuda_event -> event ) == CUDA_SUCCESS )) {
276+ (ucs_likely (uct_cuda_base_event_is_flush (
277+ cuda_event )) ||
278+ (cuEventQuery (cuda_event -> event ) == CUDA_SUCCESS ))) {
187279 ucs_trace_data ("cuda event %p completed" , cuda_event );
188280 if (cuda_event -> comp != NULL ) {
189281 uct_invoke_completion (cuda_event -> comp , UCS_OK );
190282 }
191283
192- iface -> ops -> complete_event (& iface -> super .super , cuda_event );
284+ if (ucs_likely (!uct_cuda_base_event_is_flush (cuda_event ))) {
285+ iface -> ops -> complete_event (& iface -> super .super , cuda_event );
286+ }
287+
193288 ucs_mpool_put (cuda_event );
194289 count ++ ;
195290 }
@@ -352,19 +447,41 @@ static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface,
352447 iface -> ops -> destroy_rsc (& iface -> super .super , ctx_rsc );
353448}
354449
450+ static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = {
451+ .chunk_alloc = ucs_mpool_chunk_malloc ,
452+ .chunk_release = ucs_mpool_chunk_free ,
453+ .obj_init = NULL ,
454+ .obj_cleanup = NULL ,
455+ .obj_str = NULL
456+ };
457+
355458UCS_CLASS_INIT_FUNC (uct_cuda_iface_t , uct_iface_ops_t * tl_ops ,
356459 uct_iface_internal_ops_t * ops , uct_md_h md ,
357460 uct_worker_h worker , const uct_iface_params_t * params ,
358461 const uct_iface_config_t * tl_config ,
359462 const char * dev_name )
360463{
464+ ucs_mpool_params_t mp_params ;
465+ ucs_status_t status ;
466+
361467 UCS_CLASS_CALL_SUPER_INIT (uct_base_iface_t , tl_ops , ops , md , worker , params ,
362468 tl_config UCS_STATS_ARG (params -> stats_root )
363469 UCS_STATS_ARG (dev_name ));
364470
365471 self -> eventfd = UCS_ASYNC_EVENTFD_INVALID_FD ;
366472 kh_init_inplace (cuda_ctx_rscs , & self -> ctx_rscs );
367473 ucs_queue_head_init (& self -> active_queue );
474+
475+ ucs_mpool_params_reset (& mp_params );
476+ mp_params .elem_size = sizeof (uct_cuda_flush_stream_desc_t );
477+ mp_params .ops = & uct_cuda_flush_desc_mpool_ops ;
478+ mp_params .name = "cuda_flush_descriptors" ;
479+
480+ status = ucs_mpool_init (& mp_params , & self -> flush_mpool );
481+ if (status != UCS_OK ) {
482+ return status ;
483+ }
484+
368485 return UCS_OK ;
369486}
370487
@@ -381,6 +498,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_iface_t)
381498
382499 kh_destroy_inplace (cuda_ctx_rscs , & self -> ctx_rscs );
383500 ucs_async_eventfd_destroy (self -> eventfd );
501+ ucs_mpool_cleanup (& self -> flush_mpool , 1 );
384502}
385503
386504UCS_CLASS_DEFINE (uct_cuda_iface_t , uct_base_iface_t );
0 commit comments