diff --git a/src/uct/cuda/base/cuda_iface.c b/src/uct/cuda/base/cuda_iface.c index 678c753a830..8199e4c5890 100644 --- a/src/uct/cuda/base/cuda_iface.c +++ b/src/uct/cuda/base/cuda_iface.c @@ -173,6 +173,97 @@ ucs_status_t uct_cuda_base_iface_event_fd_arm(uct_iface_h tl_iface, return UCS_OK; } +static void uct_cuda_base_stream_flushed_cb(uct_completion_t *self) +{ + uct_cuda_flush_stream_desc_t *desc = + ucs_container_of(self, uct_cuda_flush_stream_desc_t, comp); + + if (--desc->flush_desc->stream_counter == 0) { + uct_invoke_completion(desc->flush_desc->comp, UCS_OK); + ucs_free(desc->flush_desc); + } +} + +/* Flush is done by enqueueing flush events on all active streams and wait + * for them to finish. On each flush event completion, we decrement a shared + * counter and once it hits zero, flush is completed. */ +ucs_status_t +uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp) +{ + uct_base_ep_t UCS_V_UNUSED *ep = ucs_derived_of(tl_ep, uct_base_ep_t); + uct_cuda_iface_t *iface = ucs_derived_of(tl_ep->iface, + uct_cuda_iface_t); + uct_cuda_flush_desc_t *flush_desc; + uct_cuda_flush_stream_desc_t *flush_stream_desc; + uct_cuda_queue_desc_t *q_desc; + uct_cuda_event_desc_t *event_desc; + ucs_queue_head_t *event_queue; + ucs_queue_iter_t iter; + unsigned stream_index; + + if (ucs_queue_is_empty(&iface->active_queue)) { + UCT_TL_EP_STAT_FLUSH(ep); + return UCS_OK; + } + + if (comp == NULL) { + goto out; + } + + /* Allocate base flush descriptor */ + flush_desc = ucs_malloc(sizeof(*flush_desc), "cuda_flush_desc"); + if (flush_desc == NULL) { + return UCS_ERR_NO_MEMORY; + } + + flush_desc->comp = comp; + flush_desc->stream_counter = 0; + + /* For each active stream, init a flush event and enqueue it on the + * stream */ + ucs_queue_for_each(q_desc, &iface->active_queue, queue) { + flush_stream_desc = ucs_mpool_get(&iface->flush_mpool); + if (flush_stream_desc == NULL) { + goto error; + } + + flush_stream_desc->flush_desc = flush_desc; + flush_stream_desc->comp.func = uct_cuda_base_stream_flushed_cb; + flush_stream_desc->comp.count = 1; + flush_stream_desc->super.comp = &flush_stream_desc->comp; + ucs_queue_push(&q_desc->event_queue, &flush_stream_desc->super.queue); + flush_desc->stream_counter++; + } + +out: + UCT_TL_EP_STAT_FLUSH_WAIT(ep); + return UCS_INPROGRESS; + +error: + /* Rollback enqueued items in case of error */ + for (iter = ucs_queue_iter_begin(&iface->active_queue), stream_index = 0; + stream_index < flush_desc->stream_counter; + iter = ucs_queue_iter_next(iter), ++stream_index) { + event_queue = &ucs_queue_iter_elem(q_desc, iter, queue)->event_queue; + event_desc = ucs_queue_tail_elem_non_empty(event_queue, + uct_cuda_event_desc_t, + queue); + + ucs_queue_remove(event_queue, &event_desc->queue); + ucs_mpool_put((uct_cuda_flush_stream_desc_t*)event_desc); + } + + ucs_free(flush_desc); + return UCS_ERR_NO_MEMORY; +} + +static UCS_F_ALWAYS_INLINE int +uct_cuda_base_event_is_flush(uct_cuda_event_desc_t *event) +{ + return (event->comp != NULL) && + (event->comp->func == uct_cuda_base_stream_flushed_cb); +} + static UCS_F_ALWAYS_INLINE unsigned uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface, ucs_queue_head_t *queue_head, @@ -183,13 +274,18 @@ uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface, ucs_queue_for_each_extract(cuda_event, queue_head, queue, (count < max_events) && - (cuEventQuery(cuda_event->event) == CUDA_SUCCESS)) { + (ucs_unlikely(uct_cuda_base_event_is_flush( + cuda_event)) || + (cuEventQuery(cuda_event->event) == CUDA_SUCCESS))) { ucs_trace_data("cuda event %p completed", cuda_event); if (cuda_event->comp != NULL) { uct_invoke_completion(cuda_event->comp, UCS_OK); } - iface->ops->complete_event(&iface->super.super, cuda_event); + if (ucs_likely(!uct_cuda_base_event_is_flush(cuda_event))) { + iface->ops->complete_event(&iface->super.super, cuda_event); + } + ucs_mpool_put(cuda_event); count++; } @@ -352,12 +448,23 @@ static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface, iface->ops->destroy_rsc(&iface->super.super, ctx_rsc); } +static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = { + .chunk_alloc = ucs_mpool_chunk_malloc, + .chunk_release = ucs_mpool_chunk_free, + .obj_init = NULL, + .obj_cleanup = NULL, + .obj_str = NULL +}; + UCS_CLASS_INIT_FUNC(uct_cuda_iface_t, uct_iface_ops_t *tl_ops, uct_iface_internal_ops_t *ops, uct_md_h md, uct_worker_h worker, const uct_iface_params_t *params, const uct_iface_config_t *tl_config, const char *dev_name) { + ucs_mpool_params_t mp_params; + ucs_status_t status; + UCS_CLASS_CALL_SUPER_INIT(uct_base_iface_t, tl_ops, ops, md, worker, params, tl_config UCS_STATS_ARG(params->stats_root) UCS_STATS_ARG(dev_name)); @@ -365,6 +472,17 @@ UCS_CLASS_INIT_FUNC(uct_cuda_iface_t, uct_iface_ops_t *tl_ops, self->eventfd = UCS_ASYNC_EVENTFD_INVALID_FD; kh_init_inplace(cuda_ctx_rscs, &self->ctx_rscs); ucs_queue_head_init(&self->active_queue); + + ucs_mpool_params_reset(&mp_params); + mp_params.elem_size = sizeof(uct_cuda_flush_stream_desc_t); + mp_params.ops = &uct_cuda_flush_desc_mpool_ops; + mp_params.name = "cuda_flush_descriptors"; + + status = ucs_mpool_init(&mp_params, &self->flush_mpool); + if (status != UCS_OK) { + return status; + } + return UCS_OK; } @@ -381,6 +499,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_iface_t) kh_destroy_inplace(cuda_ctx_rscs, &self->ctx_rscs); ucs_async_eventfd_destroy(self->eventfd); + ucs_mpool_cleanup(&self->flush_mpool, 1); } UCS_CLASS_DEFINE(uct_cuda_iface_t, uct_base_iface_t); diff --git a/src/uct/cuda/base/cuda_iface.h b/src/uct/cuda/base/cuda_iface.h index ff0a2efea86..71407482a10 100644 --- a/src/uct/cuda/base/cuda_iface.h +++ b/src/uct/cuda/base/cuda_iface.h @@ -95,6 +95,23 @@ typedef struct { } uct_cuda_event_desc_t; +/* Base flush descriptor */ +typedef struct { + /* How many streams are currently active */ + uint32_t stream_counter; + uct_completion_t *comp; +} uct_cuda_flush_desc_t; + + +/* Stream Flush descriptor */ +typedef struct { + uct_cuda_event_desc_t super; + /* Pointer to base flush descriptor */ + uct_cuda_flush_desc_t *flush_desc; + uct_completion_t comp; +} uct_cuda_flush_stream_desc_t; + + typedef struct { /* CUDA context handle */ CUcontext ctx; @@ -130,6 +147,8 @@ typedef struct { /* list of queues which require progress */ ucs_queue_head_t active_queue; uct_cuda_iface_ops_t *ops; + /* Pool for flush events */ + ucs_mpool_t flush_mpool; struct { unsigned max_events; @@ -148,6 +167,8 @@ unsigned uct_cuda_base_iface_progress(uct_iface_h tl_iface); ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags, uct_completion_t *comp); +ucs_status_t +uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp); ucs_status_t uct_cuda_base_query_devices_common( uct_md_h md, uct_device_type_t dev_type, diff --git a/src/uct/cuda/cuda_copy/cuda_copy_iface.c b/src/uct/cuda/cuda_copy/cuda_copy_iface.c index f2c4bcb99d1..e2d8bd51b30 100644 --- a/src/uct/cuda/cuda_copy/cuda_copy_iface.c +++ b/src/uct/cuda/cuda_copy/cuda_copy_iface.c @@ -158,7 +158,7 @@ static uct_iface_ops_t uct_cuda_copy_iface_ops = { .ep_put_zcopy = uct_cuda_copy_ep_put_zcopy, .ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy, .ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function, - .ep_flush = uct_base_ep_flush, + .ep_flush = uct_cuda_base_ep_flush, .ep_fence = uct_base_ep_fence, .ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_copy_ep_t), .ep_destroy = UCS_CLASS_DELETE_FUNC_NAME(uct_cuda_copy_ep_t), diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c index 116f3791f35..75a4e1af786 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c @@ -336,7 +336,7 @@ static uct_iface_ops_t uct_cuda_ipc_iface_ops = { .ep_put_zcopy = uct_cuda_ipc_ep_put_zcopy, .ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy, .ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function, - .ep_flush = uct_base_ep_flush, + .ep_flush = uct_cuda_base_ep_flush, .ep_fence = uct_base_ep_fence, .ep_check = (uct_ep_check_func_t)ucs_empty_function_return_unsupported, .ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_ipc_ep_t),