Skip to content

Commit acbcae8

Browse files
authored
Merge pull request #7052 from hoopoepg/topic/pml-ucx-fix-datatype-leak-v3.0
v3.0.x: pml_ucx: add ompi datatype attribute to release ucp_datatype - v3.0
2 parents f90eda1 + 4e9bdf5 commit acbcae8

File tree

4 files changed

+97
-16
lines changed

4 files changed

+97
-16
lines changed

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "opal/runtime/opal.h"
1818
#include "opal/mca/pmix/pmix.h"
19+
#include "ompi/attribute/attribute.h"
1920
#include "ompi/message/message.h"
2021
#include "ompi/mca/pml/base/pml_base_bsend.h"
2122
#include "pml_ucx_request.h"
@@ -184,9 +185,9 @@ int mca_pml_ucx_close(void)
184185
int mca_pml_ucx_init(void)
185186
{
186187
ucp_worker_params_t params;
187-
ucs_status_t status;
188188
ucp_worker_attr_t attr;
189-
int rc;
189+
ucs_status_t status;
190+
int i, rc;
190191

191192
PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
192193

@@ -203,30 +204,34 @@ int mca_pml_ucx_init(void)
203204
&ompi_pml_ucx.ucp_worker);
204205
if (UCS_OK != status) {
205206
PML_UCX_ERROR("Failed to create UCP worker");
206-
return OMPI_ERROR;
207+
rc = OMPI_ERROR;
208+
goto err;
207209
}
208210

209211
attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
210212
status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
211213
if (UCS_OK != status) {
212-
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
213-
ompi_pml_ucx.ucp_worker = NULL;
214214
PML_UCX_ERROR("Failed to query UCP worker thread level");
215-
return OMPI_ERROR;
215+
rc = OMPI_ERROR;
216+
goto err_destroy_worker;
216217
}
217218

218-
if (ompi_mpi_thread_multiple && attr.thread_mode != UCS_THREAD_MODE_MULTI) {
219+
if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) {
219220
/* UCX does not support multithreading, disqualify current PML for now */
220221
/* TODO: we should let OMPI fall back to THREAD_SINGLE mode */
221-
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
222-
ompi_pml_ucx.ucp_worker = NULL;
223222
PML_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE");
224-
return OMPI_ERROR;
223+
rc = OMPI_ERR_NOT_SUPPORTED;
224+
goto err_destroy_worker;
225225
}
226226

227227
rc = mca_pml_ucx_send_worker_address();
228228
if (rc < 0) {
229-
return rc;
229+
goto err_destroy_worker;
230+
}
231+
232+
ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID;
233+
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
234+
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
230235
}
231236

232237
/* Initialize the free lists */
@@ -242,15 +247,34 @@ int mca_pml_ucx_init(void)
242247
PML_UCX_VERBOSE(2, "created ucp context %p, worker %p",
243248
(void *)ompi_pml_ucx.ucp_context,
244249
(void *)ompi_pml_ucx.ucp_worker);
245-
return OMPI_SUCCESS;
250+
return rc;
251+
252+
err_destroy_worker:
253+
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
254+
ompi_pml_ucx.ucp_worker = NULL;
255+
err:
256+
return OMPI_ERROR;
246257
}
247258

248259
int mca_pml_ucx_cleanup(void)
249260
{
261+
int i;
262+
250263
PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup");
251264

252265
opal_progress_unregister(mca_pml_ucx_progress);
253266

267+
if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) {
268+
ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false);
269+
}
270+
271+
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
272+
if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) {
273+
ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]);
274+
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
275+
}
276+
}
277+
254278
ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID;
255279
OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req);
256280
OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req);
@@ -448,6 +472,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
448472

449473
int mca_pml_ucx_enable(bool enable)
450474
{
475+
ompi_attribute_fn_ptr_union_t copy_fn;
476+
ompi_attribute_fn_ptr_union_t del_fn;
477+
int ret;
478+
479+
/* Create a key for adding custom attributes to datatypes */
480+
copy_fn.attr_datatype_copy_fn =
481+
(MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN;
482+
del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn;
483+
ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn,
484+
&ompi_pml_ucx.datatype_attr_keyval, NULL, 0,
485+
NULL);
486+
if (ret != OMPI_SUCCESS) {
487+
PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret);
488+
return ret;
489+
}
490+
451491
PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs,
452492
mca_pml_ucx_persistent_request_t,
453493
128, -1, 128);

ompi/mca/pml/ucx/pml_ucx.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ompi/mca/pml/pml.h"
1616
#include "ompi/mca/pml/base/base.h"
1717
#include "ompi/datatype/ompi_datatype.h"
18+
#include "ompi/datatype/ompi_datatype_internal.h"
1819
#include "ompi/communicator/communicator.h"
1920
#include "ompi/request/request.h"
2021

@@ -37,6 +38,10 @@ struct mca_pml_ucx_module {
3738
ucp_context_h ucp_context;
3839
ucp_worker_h ucp_worker;
3940

41+
/* Datatypes */
42+
int datatype_attr_keyval;
43+
ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED];
44+
4045
/* Requests */
4146
mca_pml_ucx_freelist_t persistent_reqs;
4247
ompi_request_t completed_send_req;

ompi/mca/pml/ucx/pml_ucx_datatype.c

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "pml_ucx_datatype.h"
1111

1212
#include "ompi/runtime/mpiruntime.h"
13+
#include "ompi/attribute/attribute.h"
1314

1415
#include <inttypes.h>
1516

@@ -108,12 +109,25 @@ static ucp_generic_dt_ops_t pml_ucx_generic_datatype_ops = {
108109
.finish = pml_ucx_generic_datatype_finish
109110
};
110111

112+
/**
 * Attribute delete callback used to release the UCX datatype that is
 * attached to an OMPI datatype.
 *
 * Destroys the UCP datatype carried in the attribute value and resets the
 * handle cached in datatype->pml_data to PML_UCX_DATATYPE_INVALID.
 *
 * @param datatype  OMPI datatype the attribute is attached to.
 * @param keyval    Attribute key (unused).
 * @param attr_val  Attribute value; carries the ucp_datatype_t handle
 *                  converted to a pointer.
 * @param extra     Extra state (unused).
 *
 * @return OMPI_SUCCESS.
 */
int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
                                     void *attr_val, void *extra)
{
    /* Convert through uintptr_t so the pointer -> integer conversion does not
     * depend on the pointer width matching ucp_datatype_t. */
    ucp_datatype_t ucp_datatype = (ucp_datatype_t)(uintptr_t)attr_val;

    /* Compare the two handles as ucp_datatype_t values; comparing a pointer
     * against an integer handle is not valid C. */
    PML_UCX_ASSERT(ucp_datatype == (ucp_datatype_t)datatype->pml_data);

    ucp_dt_destroy(ucp_datatype);
    datatype->pml_data = PML_UCX_DATATYPE_INVALID;
    return OMPI_SUCCESS;
}
123+
111124
ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
112125
{
113126
ucp_datatype_t ucp_datatype;
114127
ucs_status_t status;
115128
ptrdiff_t lb;
116129
size_t size;
130+
int ret;
117131

118132
ompi_datatype_type_lb(datatype, &lb);
119133

@@ -128,16 +142,33 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
128142
}
129143

130144
status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops,
131-
datatype, &ucp_datatype);
145+
datatype, &ucp_datatype);
132146
if (status != UCS_OK) {
133147
PML_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name);
134148
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
135149
}
136150

151+
datatype->pml_data = ucp_datatype;
152+
153+
/* Add custom attribute, to clean up UCX resources when OMPI datatype is
154+
* released.
155+
*/
156+
if (ompi_datatype_is_predefined(datatype)) {
157+
PML_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED);
158+
ompi_pml_ucx.predefined_types[datatype->id] = ucp_datatype;
159+
} else {
160+
ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash,
161+
ompi_pml_ucx.datatype_attr_keyval,
162+
(void*)ucp_datatype, false);
163+
if (ret != OMPI_SUCCESS) {
164+
PML_UCX_ERROR("Failed to add UCX datatype attribute for %s: %d",
165+
datatype->name, ret);
166+
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
167+
}
168+
}
169+
137170
PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype)
138-
// TODO put this on a list to be destroyed later
139171

140-
datatype->pml_data = ucp_datatype;
141172
return ucp_datatype;
142173
}
143174

ompi/mca/pml/ucx/pml_ucx_datatype.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "pml_ucx.h"
1414

1515

16+
#define PML_UCX_DATATYPE_INVALID 0
17+
1618
struct pml_ucx_convertor {
1719
opal_free_list_item_t super;
1820
ompi_datatype_t *datatype;
@@ -22,14 +24,17 @@ struct pml_ucx_convertor {
2224

2325
ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype);
2426

27+
int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
28+
void *attr_val, void *extra);
29+
2530
OBJ_CLASS_DECLARATION(mca_pml_ucx_convertor_t);
2631

2732

2833
static inline ucp_datatype_t mca_pml_ucx_get_datatype(ompi_datatype_t *datatype)
2934
{
3035
ucp_datatype_t ucp_type = datatype->pml_data;
3136

32-
if (OPAL_LIKELY(ucp_type != 0)) {
37+
if (OPAL_LIKELY(ucp_type != PML_UCX_DATATYPE_INVALID)) {
3338
return ucp_type;
3439
}
3540

0 commit comments

Comments
 (0)