Skip to content

Commit ff48070

Browse files
authored
Merge pull request #7065 from janjust/master
oshmem: fix race condition on new contexts
2 parents 40f2ec9 + ebb985d commit ff48070

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -573,19 +573,17 @@ static inline void _ctx_add(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx_t
573573
array->ctxs_count++;
574574
}
575575

576-
static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx_t *ctx)
576+
static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx_t *ctx, int i)
577577
{
578-
int i;
579-
580-
for (i = 0; i < array->ctxs_count; i++) {
578+
for (; i < array->ctxs_count; i++) {
581579
if (array->ctxs[i] == ctx) {
582580
array->ctxs[i] = array->ctxs[array->ctxs_count-1];
583581
array->ctxs[array->ctxs_count-1] = NULL;
582+
array->ctxs_count--;
584583
break;
585584
}
586585
}
587586

588-
array->ctxs_count--;
589587
opal_atomic_wmb ();
590588
}
591589

@@ -681,27 +679,45 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
681679

682680
int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
683681
{
684-
mca_spml_ucx_ctx_t *ucx_ctx;
685-
int rc;
682+
mca_spml_ucx_ctx_t *ucx_ctx = NULL;
683+
mca_spml_ucx_ctx_array_t *idle_array = &mca_spml_ucx.idle_array;
684+
mca_spml_ucx_ctx_array_t *active_array = &mca_spml_ucx.active_array;
685+
int i = 0, rc = OSHMEM_SUCCESS;
686686

687687
/* Take a lock controlling context creation. AUX context may set specific
688688
* UCX parameters affecting worker creation, which are not needed for
689689
* regular contexts. */
690-
pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
691-
rc = mca_spml_ucx_ctx_create_common(options, &ucx_ctx);
692-
pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);
693-
if (rc != OSHMEM_SUCCESS) {
694-
return rc;
695-
}
696-
697-
if (mca_spml_ucx.active_array.ctxs_count == 0) {
698-
opal_progress_register(spml_ucx_ctx_progress);
699-
}
700690

691+
/* Check if we have an idle context to reuse */
701692
SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
702-
_ctx_add(&mca_spml_ucx.active_array, ucx_ctx);
693+
for (i = 0; i < idle_array->ctxs_count; i++) {
694+
if (idle_array->ctxs[i]->options & options) {
695+
ucx_ctx = idle_array->ctxs[i];
696+
_ctx_remove(idle_array, ucx_ctx, i);
697+
break;
698+
}
699+
}
703700
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
704701

702+
/* If we cannot reuse, create new ctx */
703+
if (ucx_ctx == NULL) {
704+
pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
705+
rc = mca_spml_ucx_ctx_create_common(options, &ucx_ctx);
706+
pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);
707+
if (rc != OSHMEM_SUCCESS) {
708+
return rc;
709+
}
710+
}
711+
712+
if (!(options & SHMEM_CTX_PRIVATE)) {
713+
SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
714+
_ctx_add(&mca_spml_ucx.active_array, ucx_ctx);
715+
if (mca_spml_ucx.active_array.ctxs_count == 0) {
716+
opal_progress_register(spml_ucx_ctx_progress);
717+
}
718+
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
719+
}
720+
705721
(*ctx) = (shmem_ctx_t)ucx_ctx;
706722
return OSHMEM_SUCCESS;
707723
}
@@ -711,13 +727,14 @@ void mca_spml_ucx_ctx_destroy(shmem_ctx_t ctx)
711727
MCA_SPML_CALL(quiet(ctx));
712728

713729
SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
714-
_ctx_remove(&mca_spml_ucx.active_array, (mca_spml_ucx_ctx_t *)ctx);
730+
if (!(((mca_spml_ucx_ctx_t *)ctx)->options & SHMEM_CTX_PRIVATE)) {
731+
_ctx_remove(&mca_spml_ucx.active_array, (mca_spml_ucx_ctx_t *)ctx, 0);
732+
}
715733
_ctx_add(&mca_spml_ucx.idle_array, (mca_spml_ucx_ctx_t *)ctx);
716-
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
717-
718734
if (!mca_spml_ucx.active_array.ctxs_count) {
719735
opal_progress_unregister(spml_ucx_ctx_progress);
720736
}
737+
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
721738
}
722739

723740
int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_addr, int src)

0 commit comments

Comments
 (0)