@@ -840,6 +840,7 @@ static void _common_ucx_tls_cleanup(_tlocal_table_t *tls)
840840 size = tls -> ctx_tbl_size ;
841841 for (i = 0 ; i < size ; i ++ ) {
842842 if (NULL != tls -> ctx_tbl [i ]-> gctx ){
843+ assert (tls -> ctx_tbl [i ]-> refcnt == 0 );
843844 _tlocal_ctx_record_cleanup (tls -> ctx_tbl [i ]);
844845 }
845846 free (tls -> ctx_tbl [i ]);
@@ -909,6 +910,11 @@ _tlocal_ctx_record_cleanup(_tlocal_ctx_t *ctx_rec)
909910 if (NULL == ctx_rec -> gctx ) {
910911 return OPAL_SUCCESS ;
911912 }
913+
914+ if (ctx_rec -> refcnt > 0 ) {
915+ return OPAL_SUCCESS ;
916+ }
917+
912918 /* Remove myself from the communication context structure
913919 * This may result in context release as we are using
914920 * delayed cleanup */
@@ -934,7 +940,7 @@ _tlocal_add_ctx(_tlocal_table_t *tls, opal_common_ucx_ctx_t *ctx)
934940 /* Try to find available record in the TLS table
935941 * In parallel perform deferred cleanups */
936942 for (i = 0 ; i < tls -> ctx_tbl_size ; i ++ ) {
937- if (NULL != tls -> ctx_tbl [i ]-> gctx ) {
943+ if (NULL != tls -> ctx_tbl [i ]-> gctx && tls -> ctx_tbl [ i ] -> refcnt == 0 ) {
938944 if (tls -> ctx_tbl [i ]-> gctx -> released ) {
939945 /* Found dirty record, need to clean first */
940946 _tlocal_ctx_record_cleanup (tls -> ctx_tbl [i ]);
@@ -1059,6 +1065,10 @@ _tlocal_mem_record_cleanup(_tlocal_mem_t *mem_rec)
10591065 free (mem_rec -> mem_tls_ptr );
10601066 }
10611067
1068+ assert (mem_rec -> ctx_rec != NULL );
1069+ OPAL_ATOMIC_ADD_FETCH32 (& mem_rec -> ctx_rec -> refcnt , -1 );
1070+ assert (mem_rec -> ctx_rec -> refcnt >= 0 );
1071+
10621072 free (mem_rec -> mem );
10631073
10641074 memset (mem_rec , 0 , sizeof (* mem_rec ));
@@ -1107,6 +1117,9 @@ static _tlocal_mem_t *_tlocal_add_mem(_tlocal_table_t *tls,
11071117 WPOOL_DBG_OUT ("tls = %p, ctx = %p\n" ,
11081118 (void * )tls , (void * )mem -> ctx );
11091119
1120+ tls -> mem_tbl [free_idx ]-> ctx_rec = ctx_rec ;
1121+ OPAL_ATOMIC_ADD_FETCH32 (& ctx_rec -> refcnt , 1 );
1122+
11101123 tls -> mem_tbl [free_idx ]-> mem -> worker = ctx_rec -> winfo ;
11111124 tls -> mem_tbl [free_idx ]-> mem -> rkeys = calloc (mem -> ctx -> comm_size ,
11121125 sizeof (* tls -> mem_tbl [free_idx ]-> mem -> rkeys ));
0 commit comments