Skip to content

Commit 0467037

Browse files
committed
UCT/CUDA: Add support for flush_ep
1 parent 4b3747d commit 0467037

File tree

4 files changed

+143
-4
lines changed

4 files changed

+143
-4
lines changed

src/uct/cuda/base/cuda_iface.c

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,96 @@ ucs_status_t uct_cuda_base_iface_event_fd_arm(uct_iface_h tl_iface,
173173
return UCS_OK;
174174
}
175175

176+
static void uct_cuda_base_stream_flushed_cb(uct_completion_t *self)
177+
{
178+
uct_cuda_flush_stream_desc_t *desc =
179+
ucs_container_of(self, uct_cuda_flush_stream_desc_t, comp);
180+
181+
if (--desc->flush_desc->stream_counter == 0) {
182+
uct_invoke_completion(desc->flush_desc->comp, UCS_OK);
183+
ucs_free(desc->flush_desc);
184+
}
185+
}
186+
187+
/* Flush is done by enqueueing flush events on all active streams and wait
188+
* for them to finish. On each flush event completion, we decrement a shared
189+
* counter and once it hits zero, flush is completed. */
190+
ucs_status_t
191+
uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp)
192+
{
193+
uct_base_ep_t *ep = ucs_derived_of(tl_ep, uct_base_ep_t);
194+
uct_cuda_iface_t *iface = ucs_derived_of(tl_ep->iface, uct_cuda_iface_t);
195+
uct_cuda_flush_desc_t *flush_desc;
196+
uct_cuda_flush_stream_desc_t *flush_stream_desc;
197+
uct_cuda_queue_desc_t *q_desc;
198+
uct_cuda_event_desc_t *event_desc;
199+
ucs_queue_head_t *event_queue;
200+
ucs_queue_iter_t iter;
201+
unsigned stream_index;
202+
203+
if (ucs_queue_is_empty(&iface->active_queue)) {
204+
UCT_TL_EP_STAT_FLUSH(ep);
205+
return UCS_OK;
206+
}
207+
208+
if (comp == NULL) {
209+
goto out;
210+
}
211+
212+
/* Allocate base flush descriptor */
213+
flush_desc = ucs_malloc(sizeof(*flush_desc), "cuda_flush_desc");
214+
if (flush_desc == NULL) {
215+
return UCS_ERR_NO_MEMORY;
216+
}
217+
218+
flush_desc->comp = comp;
219+
flush_desc->stream_counter = 0;
220+
221+
/* For each active stream, init a flush event and enqueue it on the
222+
* stream */
223+
ucs_queue_for_each(q_desc, &iface->active_queue, queue) {
224+
flush_stream_desc = ucs_mpool_get(&iface->flush_mpool);
225+
if (flush_stream_desc == NULL) {
226+
goto error;
227+
}
228+
229+
flush_stream_desc->flush_desc = flush_desc;
230+
flush_stream_desc->comp.func = uct_cuda_base_stream_flushed_cb;
231+
flush_stream_desc->comp.count = 1;
232+
flush_stream_desc->super.comp = &flush_stream_desc->comp;
233+
ucs_queue_push(&q_desc->event_queue, &flush_stream_desc->super.queue);
234+
flush_desc->stream_counter++;
235+
}
236+
237+
out:
238+
UCT_TL_EP_STAT_FLUSH_WAIT(ep);
239+
return UCS_INPROGRESS;
240+
241+
error:
242+
/* Rollback enqueued items in case of error */
243+
for (iter = ucs_queue_iter_begin(&iface->active_queue), stream_index = 0;
244+
stream_index < flush_desc->stream_counter;
245+
iter = ucs_queue_iter_next(iter), ++stream_index) {
246+
event_queue = &ucs_queue_iter_elem(q_desc, iter, queue)->event_queue;
247+
event_desc = ucs_queue_tail_elem_non_empty(event_queue,
248+
uct_cuda_event_desc_t,
249+
queue);
250+
251+
ucs_queue_remove(event_queue, &event_desc->queue);
252+
ucs_mpool_put((uct_cuda_flush_stream_desc_t*)event_desc);
253+
}
254+
255+
ucs_free(flush_desc);
256+
return UCS_ERR_NO_MEMORY;
257+
}
258+
259+
static UCS_F_ALWAYS_INLINE int
260+
uct_cuda_base_event_is_flush(uct_cuda_event_desc_t *event)
261+
{
262+
return (event->comp != NULL) &&
263+
(event->comp->func == uct_cuda_base_stream_flushed_cb);
264+
}
265+
176266
static UCS_F_ALWAYS_INLINE unsigned
177267
uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface,
178268
ucs_queue_head_t *queue_head,
@@ -183,13 +273,18 @@ uct_cuda_base_progress_event_queue(uct_cuda_iface_t *iface,
183273

184274
ucs_queue_for_each_extract(cuda_event, queue_head, queue,
185275
(count < max_events) &&
186-
(cuEventQuery(cuda_event->event) == CUDA_SUCCESS)) {
276+
(ucs_likely(uct_cuda_base_event_is_flush(
277+
cuda_event)) ||
278+
(cuEventQuery(cuda_event->event) == CUDA_SUCCESS))) {
187279
ucs_trace_data("cuda event %p completed", cuda_event);
188280
if (cuda_event->comp != NULL) {
189281
uct_invoke_completion(cuda_event->comp, UCS_OK);
190282
}
191283

192-
iface->ops->complete_event(&iface->super.super, cuda_event);
284+
if (ucs_likely(!uct_cuda_base_event_is_flush(cuda_event))) {
285+
iface->ops->complete_event(&iface->super.super, cuda_event);
286+
}
287+
193288
ucs_mpool_put(cuda_event);
194289
count++;
195290
}
@@ -352,19 +447,41 @@ static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface,
352447
iface->ops->destroy_rsc(&iface->super.super, ctx_rsc);
353448
}
354449

450+
static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = {
451+
.chunk_alloc = ucs_mpool_chunk_malloc,
452+
.chunk_release = ucs_mpool_chunk_free,
453+
.obj_init = NULL,
454+
.obj_cleanup = NULL,
455+
.obj_str = NULL
456+
};
457+
355458
UCS_CLASS_INIT_FUNC(uct_cuda_iface_t, uct_iface_ops_t *tl_ops,
356459
uct_iface_internal_ops_t *ops, uct_md_h md,
357460
uct_worker_h worker, const uct_iface_params_t *params,
358461
const uct_iface_config_t *tl_config,
359462
const char *dev_name)
360463
{
464+
ucs_mpool_params_t mp_params;
465+
ucs_status_t status;
466+
361467
UCS_CLASS_CALL_SUPER_INIT(uct_base_iface_t, tl_ops, ops, md, worker, params,
362468
tl_config UCS_STATS_ARG(params->stats_root)
363469
UCS_STATS_ARG(dev_name));
364470

365471
self->eventfd = UCS_ASYNC_EVENTFD_INVALID_FD;
366472
kh_init_inplace(cuda_ctx_rscs, &self->ctx_rscs);
367473
ucs_queue_head_init(&self->active_queue);
474+
475+
ucs_mpool_params_reset(&mp_params);
476+
mp_params.elem_size = sizeof(uct_cuda_flush_stream_desc_t);
477+
mp_params.ops = &uct_cuda_flush_desc_mpool_ops;
478+
mp_params.name = "cuda_flush_descriptors";
479+
480+
status = ucs_mpool_init(&mp_params, &self->flush_mpool);
481+
if (status != UCS_OK) {
482+
return status;
483+
}
484+
368485
return UCS_OK;
369486
}
370487

@@ -381,6 +498,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_iface_t)
381498

382499
kh_destroy_inplace(cuda_ctx_rscs, &self->ctx_rscs);
383500
ucs_async_eventfd_destroy(self->eventfd);
501+
ucs_mpool_cleanup(&self->flush_mpool, 1);
384502
}
385503

386504
UCS_CLASS_DEFINE(uct_cuda_iface_t, uct_base_iface_t);

src/uct/cuda/base/cuda_iface.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,23 @@ typedef struct {
9595
} uct_cuda_event_desc_t;
9696

9797

98+
/* Base flush descriptor */
99+
typedef struct {
100+
/* How many streams are currently active */
101+
uint32_t stream_counter;
102+
uct_completion_t *comp;
103+
} uct_cuda_flush_desc_t;
104+
105+
106+
/* Stream Flush descriptor */
107+
typedef struct {
108+
uct_cuda_event_desc_t super;
109+
/* Pointer to base flush descriptor */
110+
uct_cuda_flush_desc_t *flush_desc;
111+
uct_completion_t comp;
112+
} uct_cuda_flush_stream_desc_t;
113+
114+
98115
typedef struct {
99116
/* CUDA context handle */
100117
CUcontext ctx;
@@ -130,6 +147,8 @@ typedef struct {
130147
/* list of queues which require progress */
131148
ucs_queue_head_t active_queue;
132149
uct_cuda_iface_ops_t *ops;
150+
/* Pool for flush events */
151+
ucs_mpool_t flush_mpool;
133152

134153
struct {
135154
unsigned max_events;
@@ -148,6 +167,8 @@ unsigned uct_cuda_base_iface_progress(uct_iface_h tl_iface);
148167
ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags,
149168
uct_completion_t *comp);
150169

170+
ucs_status_t
171+
uct_cuda_base_ep_flush(uct_ep_h tl_ep, unsigned flags, uct_completion_t *comp);
151172
ucs_status_t
152173
uct_cuda_base_query_devices_common(
153174
uct_md_h md, uct_device_type_t dev_type,

src/uct/cuda/cuda_copy/cuda_copy_iface.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ static uct_iface_ops_t uct_cuda_copy_iface_ops = {
158158
.ep_put_zcopy = uct_cuda_copy_ep_put_zcopy,
159159
.ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy,
160160
.ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function,
161-
.ep_flush = uct_base_ep_flush,
161+
.ep_flush = uct_cuda_base_ep_flush,
162162
.ep_fence = uct_base_ep_fence,
163163
.ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_copy_ep_t),
164164
.ep_destroy = UCS_CLASS_DELETE_FUNC_NAME(uct_cuda_copy_ep_t),

src/uct/cuda/cuda_ipc/cuda_ipc_iface.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ static uct_iface_ops_t uct_cuda_ipc_iface_ops = {
336336
.ep_put_zcopy = uct_cuda_ipc_ep_put_zcopy,
337337
.ep_pending_add = (uct_ep_pending_add_func_t)ucs_empty_function_return_busy,
338338
.ep_pending_purge = (uct_ep_pending_purge_func_t)ucs_empty_function,
339-
.ep_flush = uct_base_ep_flush,
339+
.ep_flush = uct_cuda_base_ep_flush,
340340
.ep_fence = uct_base_ep_fence,
341341
.ep_check = (uct_ep_check_func_t)ucs_empty_function_return_unsupported,
342342
.ep_create = UCS_CLASS_NEW_FUNC_NAME(uct_cuda_ipc_ep_t),

0 commit comments

Comments
 (0)