Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 121 additions & 2 deletions src/uct/cuda/base/cuda_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,97 @@ ucs_status_t uct_cuda_base_iface_event_fd_arm(uct_iface_h tl_iface,
return UCS_OK;
}

/*
 * Completion callback of a single per-stream flush event.
 *
 * Every stream flush descriptor that belongs to the same flush request points
 * to one shared base descriptor. Each invocation decrements the shared
 * counter; the invocation that brings it to zero reports UCS_OK through the
 * completion stored in the base descriptor and releases the base descriptor.
 *
 * @param self  Completion embedded in a uct_cuda_flush_stream_desc_t.
 */
static void uct_cuda_base_stream_flushed_cb(uct_completion_t *self)
{
    uct_cuda_flush_stream_desc_t *stream_desc;
    uct_cuda_flush_desc_t *shared_desc;

    stream_desc = ucs_container_of(self, uct_cuda_flush_stream_desc_t, comp);
    shared_desc = stream_desc->flush_desc;

    shared_desc->stream_counter--;
    if (shared_desc->stream_counter != 0) {
        /* Flush events of other streams are still outstanding */
        return;
    }

    /* Last outstanding stream: complete the whole flush request */
    uct_invoke_completion(shared_desc->comp, UCS_OK);
    ucs_free(shared_desc);
}

/*
 * Flush an endpoint by appending a flush marker to the event queue of every
 * active stream. Each marker completion decrements a counter shared by all
 * markers of this flush; when the counter reaches zero the caller's completion
 * is invoked.
 *
 * @param tl_ep  Endpoint to flush; its interface holds the active queues.
 * @param flags  Flush flags; not referenced by this implementation.
 * @param comp   Completion to invoke when all markers are done. When NULL,
 *               nothing is enqueued.
 *
 * @return UCS_OK if the interface has no active queues, UCS_INPROGRESS if
 *         there are active queues (markers are enqueued only when comp is not
 *         NULL), UCS_ERR_NO_MEMORY if a descriptor could not be allocated (in
 *         which case all markers enqueued by this call are removed again).
 */
ucs_status_t
uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp)
{
    uct_base_ep_t UCS_V_UNUSED *ep = ucs_derived_of(tl_ep, uct_base_ep_t);
    uct_cuda_iface_t *iface        = ucs_derived_of(tl_ep->iface,
                                                    uct_cuda_iface_t);
    uct_cuda_flush_desc_t *flush_desc;
    uct_cuda_flush_stream_desc_t *marker;
    uct_cuda_queue_desc_t *q_desc;
    uct_cuda_event_desc_t *tail_event;
    ucs_queue_head_t *stream_events;
    ucs_queue_iter_t iter;
    unsigned num_reverted;

    /* Nothing is in flight: the flush is already complete */
    if (ucs_queue_is_empty(&iface->active_queue)) {
        UCT_TL_EP_STAT_FLUSH(ep);
        return UCS_OK;
    }

    /* Without a completion there is nobody to notify, so no markers are
     * enqueued and the caller only learns that work is still in progress */
    if (comp != NULL) {
        /* Descriptor shared by the markers of all streams */
        flush_desc = ucs_malloc(sizeof(*flush_desc), "cuda_flush_desc");
        if (flush_desc == NULL) {
            return UCS_ERR_NO_MEMORY;
        }

        flush_desc->stream_counter = 0;
        flush_desc->comp           = comp;

        /* Append one marker at the tail of every active stream's event
         * queue, counting how many were enqueued */
        ucs_queue_for_each(q_desc, &iface->active_queue, queue) {
            marker = ucs_mpool_get(&iface->flush_mpool);
            if (marker == NULL) {
                goto err_rollback;
            }

            marker->comp.count = 1;
            marker->comp.func  = uct_cuda_base_stream_flushed_cb;
            marker->flush_desc = flush_desc;
            marker->super.comp = &marker->comp;
            ucs_queue_push(&q_desc->event_queue, &marker->super.queue);
            flush_desc->stream_counter++;
        }
    }

    UCT_TL_EP_STAT_FLUSH_WAIT(ep);
    return UCS_INPROGRESS;

err_rollback:
    /* The first stream_counter active queues each got a marker pushed at
     * their tail: walk them in the same order and take the markers back */
    iter         = ucs_queue_iter_begin(&iface->active_queue);
    num_reverted = 0;
    while (num_reverted < flush_desc->stream_counter) {
        stream_events = &ucs_queue_iter_elem(q_desc, iter, queue)->event_queue;
        tail_event    = ucs_queue_tail_elem_non_empty(stream_events,
                                                      uct_cuda_event_desc_t,
                                                      queue);

        ucs_queue_remove(stream_events, &tail_event->queue);
        ucs_mpool_put((uct_cuda_flush_stream_desc_t*)tail_event);

        iter = ucs_queue_iter_next(iter);
        ++num_reverted;
    }

    ucs_free(flush_desc);
    return UCS_ERR_NO_MEMORY;
}

/*
 * Tell whether an event descriptor is a flush marker rather than a regular
 * event: a marker is recognized by having a completion whose callback is
 * uct_cuda_base_stream_flushed_cb.
 *
 * @param event  Event descriptor to inspect.
 *
 * @return Non-zero if the descriptor is a flush marker, zero otherwise.
 */
static UCS_F_ALWAYS_INLINE int
uct_cuda_base_event_is_flush(uct_cuda_event_desc_t *event)
{
    const uct_completion_t *event_comp = event->comp;

    if (event_comp == NULL) {
        return 0;
    }

    return event_comp->func == uct_cuda_base_stream_flushed_cb;
}

static UCS_F_ALWAYS_INLINE unsigned
uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface,
ucs_queue_head_t *queue_head,
Expand All @@ -183,13 +274,18 @@ uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface,

ucs_queue_for_each_extract(cuda_event, queue_head, queue,
(count < max_events) &&
(cuEventQuery(cuda_event->event) == CUDA_SUCCESS)) {
(ucs_unlikely(uct_cuda_base_event_is_flush(
cuda_event)) ||
(cuEventQuery(cuda_event->event) == CUDA_SUCCESS))) {
ucs_trace_data("cuda event %p completed", cuda_event);
if (cuda_event->comp != NULL) {
uct_invoke_completion(cuda_event->comp, UCS_OK);
}

iface->ops->complete_event(&iface->super.super, cuda_event);
if (ucs_likely(!uct_cuda_base_event_is_flush(cuda_event))) {
iface->ops->complete_event(&iface->super.super, cuda_event);
}

ucs_mpool_put(cuda_event);
count++;
}
Expand Down Expand Up @@ -352,19 +448,41 @@ static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface,
iface->ops->destroy_rsc(&iface->super.super, ctx_rsc);
}

/* Operations of the memory pool that provides stream flush descriptors
 * (uct_cuda_flush_stream_desc_t). Chunks are obtained and released with the
 * ucs_mpool_chunk_malloc/ucs_mpool_chunk_free helpers; no per-object
 * init/cleanup/string hooks are installed. */
static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = {
.chunk_alloc = ucs_mpool_chunk_malloc,
.chunk_release = ucs_mpool_chunk_free,
.obj_init = NULL,
.obj_cleanup = NULL,
.obj_str = NULL
};

/*
 * Initializer of the common CUDA interface class.
 *
 * Runs the base interface initializer, puts the interface fields into their
 * empty state (invalid event fd, empty context-resources hash, empty active
 * queue) and creates the memory pool that flush uses for its per-stream
 * descriptors.
 *
 * @return UCS_OK on success, otherwise the error returned by the memory pool
 *         initialization (a failure of the base initializer is handled inside
 *         UCS_CLASS_CALL_SUPER_INIT).
 */
UCS_CLASS_INIT_FUNC(uct_cuda_iface_t, uct_iface_ops_t *tl_ops,
                    uct_iface_internal_ops_t *ops, uct_md_h md,
                    uct_worker_h worker, const uct_iface_params_t *params,
                    const uct_iface_config_t *tl_config,
                    const char *dev_name)
{
    ucs_mpool_params_t flush_mp_params;

    UCS_CLASS_CALL_SUPER_INIT(uct_base_iface_t, tl_ops, ops, md, worker, params,
                              tl_config UCS_STATS_ARG(params->stats_root)
                              UCS_STATS_ARG(dev_name));

    ucs_queue_head_init(&self->active_queue);
    kh_init_inplace(cuda_ctx_rscs, &self->ctx_rscs);
    self->eventfd = UCS_ASYNC_EVENTFD_INVALID_FD;

    /* Pool of per-stream flush descriptors; all parameters other than the
     * ones set below keep the values assigned by ucs_mpool_params_reset() */
    ucs_mpool_params_reset(&flush_mp_params);
    flush_mp_params.name      = "cuda_flush_descriptors";
    flush_mp_params.ops       = &uct_cuda_flush_desc_mpool_ops;
    flush_mp_params.elem_size = sizeof(uct_cuda_flush_stream_desc_t);

    /* The pool is the last thing to set up, so its status is the result */
    return ucs_mpool_init(&flush_mp_params, &self->flush_mpool);
}

Expand All @@ -381,6 +499,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_iface_t)

kh_destroy_inplace(cuda_ctx_rscs, &self->ctx_rscs);
ucs_async_eventfd_destroy(self->eventfd);
ucs_mpool_cleanup(&self->flush_mpool, 1);
}

UCS_CLASS_DEFINE(uct_cuda_iface_t, uct_base_iface_t);
21 changes: 21 additions & 0 deletions src/uct/cuda/base/cuda_iface.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ typedef struct {
} uct_cuda_event_desc_t;


/* Base flush descriptor: state shared by all per-stream flush descriptors
 * that were enqueued for one flush request */
typedef struct {
/* Number of per-stream flush events enqueued for this flush request that
 * have not completed yet; the request completes when it drops to zero */
uint32_t stream_counter;
/* Completion passed by the flush caller; invoked with UCS_OK when
 * stream_counter reaches zero */
uct_completion_t *comp;
} uct_cuda_flush_desc_t;


/* Stream flush descriptor: a flush marker that is pushed to the event queue
 * of a single stream. It extends the regular event descriptor so that it can
 * be linked into the same queue as regular events. */
typedef struct {
/* Event descriptor part; must stay the first member because the flush code
 * casts between uct_cuda_event_desc_t* and this type */
uct_cuda_event_desc_t super;
/* Pointer to base flush descriptor */
uct_cuda_flush_desc_t *flush_desc;
/* Completion that super.comp is pointed at when the marker is enqueued;
 * its callback decrements the counter in the base flush descriptor */
uct_completion_t comp;
} uct_cuda_flush_stream_desc_t;


typedef struct {
/* CUDA context handle */
CUcontext ctx;
Expand Down Expand Up @@ -130,6 +147,8 @@ typedef struct {
/* list of queues which require progress */
ucs_queue_head_t active_queue;
uct_cuda_iface_ops_t *ops;
/* Pool for flush events */
ucs_mpool_t flush_mpool;

struct {
unsigned max_events;
Expand All @@ -148,6 +167,8 @@ unsigned uct_cuda_base_iface_progress(uct_iface_h tl_iface);
ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags,
uct_completion_t *comp);

/**
 * Endpoint flush shared by the CUDA transports.
 *
 * @param tl_ep  Endpoint to flush.
 * @param flags  Flush flags.
 * @param comp   Completion to invoke once the flush is done; may be NULL.
 *
 * @return UCS_OK if there is nothing to flush, UCS_INPROGRESS if operations
 *         are still outstanding, UCS_ERR_NO_MEMORY if flush descriptors could
 *         not be allocated.
 */
ucs_status_t
uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp);
ucs_status_t
uct_cuda_base_query_devices_common(
uct_md_h md, uct_device_type_t dev_type,
Expand Down
2 changes: 1 addition & 1 deletion src/uct/cuda/cuda_copy/cuda_copy_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ static uct_iface_ops_t uct_cuda_copy_iface_ops = {
.ep_put_zcopy = uct_cuda_copy_ep_put_zcopy,
.ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy,
.ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function,
.ep_flush = uct_base_ep_flush,
.ep_flush = uct_cuda_base_ep_flush,
.ep_fence = uct_base_ep_fence,
.ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_copy_ep_t),
.ep_destroy = UCS_CLASS_DELETE_FUNC_NAME(uct_cuda_copy_ep_t),
Expand Down
2 changes: 1 addition & 1 deletion src/uct/cuda/cuda_ipc/cuda_ipc_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ static uct_iface_ops_t uct_cuda_ipc_iface_ops = {
.ep_put_zcopy = uct_cuda_ipc_ep_put_zcopy,
.ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy,
.ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function,
.ep_flush = uct_base_ep_flush,
.ep_flush = uct_cuda_base_ep_flush,
.ep_fence = uct_base_ep_fence,
.ep_check = (uct_ep_check_func_t)ucs_empty_function_return_unsupported,
.ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_ipc_ep_t),
Expand Down
Loading