Skip to content

Commit 6705378

Browse files
authored
Merge pull request open-mpi#13415 from jiaxiyan/pd
ofi: Share the domain among MTL and BTL
2 parents 12bee0f + 69d2737 commit 6705378

File tree

5 files changed

+261
-22
lines changed

5 files changed

+261
-22
lines changed

ompi/mca/mtl/ofi/mtl_ofi_component.c

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,8 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
694694
}
695695

696696
hints->domain_attr->resource_mgmt = FI_RM_ENABLED;
697+
hints->domain_attr->domain = opal_common_ofi.domain;
698+
hints->fabric_attr->fabric = opal_common_ofi.fabric;
697699

698700
/**
699701
* The EFA provider in Libfabric versions prior to 1.10 contains a bug
@@ -715,10 +717,16 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
715717
hints_dup->fabric_attr->prov_name = strdup("efa");
716718

717719
ret = fi_getinfo(fi_primary_version, NULL, NULL, 0ULL, hints_dup, &providers);
720+
if (FI_ENODATA == -ret && (hints_dup->fabric_attr->fabric || hints_dup->domain_attr->domain)) {
721+
/* Retry without fabric and domain */
722+
hints_dup->fabric_attr->fabric = NULL;
723+
hints_dup->domain_attr->domain = NULL;
724+
ret = fi_getinfo(fi_primary_version, NULL, NULL, 0ULL, hints_dup, &providers);
725+
}
718726
if (FI_ENOSYS == -ret) {
719727
/* libfabric is not new enough, fallback to use older version of API */
720728
ret = fi_getinfo(fi_alternate_version, NULL, NULL, 0ULL, hints_dup, &providers);
721-
}
729+
}
722730

723731
opal_output_verbose(1, opal_common_ofi.output,
724732
"%s:%d: EFA specific fi_getinfo(): %s\n",
@@ -756,6 +764,11 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
756764
0ULL, /* Optional flag */
757765
hints, /* In: Hints to filter providers */
758766
&providers); /* Out: List of matching providers */
767+
if (FI_ENODATA == -ret && (hints->fabric_attr->fabric || hints->domain_attr->domain)) {
768+
hints->fabric_attr->fabric = NULL;
769+
hints->domain_attr->domain = NULL;
770+
ret = fi_getinfo(fi_primary_version, NULL, NULL, 0ULL, hints, &providers);
771+
}
759772
if (FI_ENOSYS == -ret) {
760773
ret = fi_getinfo(fi_alternate_version, NULL, NULL, 0ULL, hints, &providers);
761774
}
@@ -972,9 +985,8 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
972985
* instantiate the virtual or physical network. This opens a "fabric
973986
* provider". See man fi_fabric for details.
974987
*/
975-
ret = fi_fabric(prov->fabric_attr, /* In: Fabric attributes */
976-
&ompi_mtl_ofi.fabric, /* Out: Fabric handle */
977-
NULL); /* Optional context for fabric events */
988+
ret = opal_common_ofi_fi_fabric(prov->fabric_attr, /* In: Fabric attributes */
989+
&ompi_mtl_ofi.fabric); /* Out: Fabric handle */
978990
if (0 != ret) {
979991
opal_show_help("help-mtl-ofi.txt", "OFI call fail", true,
980992
"fi_fabric",
@@ -988,10 +1000,9 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
9881000
* hardware port/collection of ports. Returns a domain object that can be
9891001
* used to create endpoints. See man fi_domain for details.
9901002
*/
991-
ret = fi_domain(ompi_mtl_ofi.fabric, /* In: Fabric object */
992-
prov, /* In: Provider */
993-
&ompi_mtl_ofi.domain, /* Out: Domain object */
994-
NULL); /* Optional context for domain events */
1003+
ret = opal_common_ofi_fi_domain(ompi_mtl_ofi.fabric, /* In: Fabric object */
1004+
prov, /* In: Provider */
1005+
&ompi_mtl_ofi.domain); /* Out: Domain object */
9951006
if (0 != ret) {
9961007
opal_show_help("help-mtl-ofi.txt", "OFI call fail", true,
9971008
"fi_domain",
@@ -1155,10 +1166,10 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
11551166
(void) fi_close((fid_t)ompi_mtl_ofi.ofi_ctxt[0].cq);
11561167
}
11571168
if (ompi_mtl_ofi.domain) {
1158-
(void) fi_close((fid_t)ompi_mtl_ofi.domain);
1169+
(void) opal_common_ofi_domain_release(ompi_mtl_ofi.domain);
11591170
}
11601171
if (ompi_mtl_ofi.fabric) {
1161-
(void) fi_close((fid_t)ompi_mtl_ofi.fabric);
1172+
(void) opal_common_ofi_fabric_release(ompi_mtl_ofi.fabric);
11621173
}
11631174
if (ompi_mtl_ofi.comm_to_context) {
11641175
free(ompi_mtl_ofi.comm_to_context);
@@ -1206,11 +1217,11 @@ ompi_mtl_ofi_finalize(struct mca_mtl_base_module_t *mtl)
12061217
}
12071218
}
12081219

1209-
if ((ret = fi_close((fid_t)ompi_mtl_ofi.domain))) {
1220+
if ((ret = opal_common_ofi_domain_release(ompi_mtl_ofi.domain))) {
12101221
goto finalize_err;
12111222
}
12121223

1213-
if ((ret = fi_close((fid_t)ompi_mtl_ofi.fabric))) {
1224+
if ((ret = opal_common_ofi_fabric_release(ompi_mtl_ofi.fabric))) {
12141225
goto finalize_err;
12151226
}
12161227

opal/mca/btl/ofi/btl_ofi_component.c

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,12 @@ static mca_btl_base_module_t **mca_btl_ofi_component_init(int *num_btl_modules,
339339
domain_attr.control_progress = progress_mode;
340340
domain_attr.data_progress = progress_mode;
341341

342+
if (enable_mpi_threads) {
343+
domain_attr.threading = FI_THREAD_SAFE;
344+
} else {
345+
domain_attr.threading = FI_THREAD_DOMAIN;
346+
}
347+
342348
/* select endpoint type */
343349
ep_attr.type = FI_EP_RDM;
344350

@@ -359,7 +365,8 @@ static mca_btl_base_module_t **mca_btl_ofi_component_init(int *num_btl_modules,
359365
tx_attr.iov_limit = 1;
360366
rx_attr.iov_limit = 1;
361367

362-
tx_attr.op_flags = FI_DELIVERY_COMPLETE;
368+
tx_attr.op_flags = FI_DELIVERY_COMPLETE | FI_COMPLETION;
369+
rx_attr.op_flags = FI_COMPLETION;
363370

364371
mca_btl_ofi_component.module_count = 0;
365372

@@ -372,9 +379,18 @@ static mca_btl_base_module_t **mca_btl_ofi_component_init(int *num_btl_modules,
372379
no_hmem:
373380
#endif
374381

382+
hints.fabric_attr->fabric = opal_common_ofi.fabric;
383+
hints.domain_attr->domain = opal_common_ofi.domain;
384+
375385
/* Do the query. The earliest version that supports FI_HMEM hints is 1.9.
376386
* The earliest version that explicitly allows providers to call the CUDA API is 1.18 */
377387
rc = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, &hints, &info_list);
388+
if (FI_ENODATA == -rc && (hints.fabric_attr->fabric || hints.domain_attr->domain)) {
389+
/* Retry without fabric and domain */
390+
hints.fabric_attr->fabric = NULL;
391+
hints.domain_attr->domain = NULL;
392+
rc = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, &hints, &info_list);
393+
}
378394
if (FI_ENOSYS == -rc) {
379395
rc = fi_getinfo(FI_VERSION(1, 9), NULL, NULL, 0, &hints, &info_list);
380396
}
@@ -553,14 +569,14 @@ static int mca_btl_ofi_init_device(struct fi_info *info)
553569
("initializing dev:%s provider:%s", linux_device_name, info->fabric_attr->prov_name));
554570

555571
/* fabric */
556-
rc = fi_fabric(ofi_info->fabric_attr, &fabric, NULL);
572+
rc = opal_common_ofi_fi_fabric(ofi_info->fabric_attr, &fabric);
557573
if (0 != rc) {
558574
BTL_VERBOSE(("%s failed fi_fabric with err=%s", linux_device_name, fi_strerror(-rc)));
559575
goto fail;
560576
}
561577

562578
/* domain */
563-
rc = fi_domain(fabric, ofi_info, &domain, NULL);
579+
rc = opal_common_ofi_fi_domain(fabric, ofi_info, &domain);
564580
if (0 != rc) {
565581
BTL_VERBOSE(("%s failed fi_domain with err=%s", linux_device_name, fi_strerror(-rc)));
566582
goto fail;
@@ -743,11 +759,11 @@ static int mca_btl_ofi_init_device(struct fi_info *info)
743759
}
744760

745761
if (NULL != domain) {
746-
fi_close(&domain->fid);
762+
opal_common_ofi_domain_release(domain);
747763
}
748764

749765
if (NULL != fabric) {
750-
fi_close(&fabric->fid);
766+
opal_common_ofi_fabric_release(fabric);
751767
}
752768
free(module);
753769

opal/mca/btl/ofi/btl_ofi_module.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,11 @@ int mca_btl_ofi_finalize(mca_btl_base_module_t *btl)
385385
}
386386

387387
if (NULL != ofi_btl->domain) {
388-
fi_close(&ofi_btl->domain->fid);
388+
opal_common_ofi_domain_release(ofi_btl->domain);
389389
}
390390

391391
if (NULL != ofi_btl->fabric) {
392-
fi_close(&ofi_btl->fabric->fid);
392+
opal_common_ofi_fabric_release(ofi_btl->fabric);
393393
}
394394

395395
if (NULL != ofi_btl->fabric_info) {

opal/mca/common/ofi/common_ofi.c

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* reserved.
77
* Copyright (c) 2020-2021 Cisco Systems, Inc. All rights reserved.
88
* Copyright (c) 2021-2023 Nanook Consulting. All rights reserved.
9-
* Copyright (c) 2021 Amazon.com, Inc. or its affiliates. All rights
9+
* Copyright (c) 2021-2025 Amazon.com, Inc. or its affiliates. All rights
1010
* reserved.
1111
* Copyright (c) 2023 UT-Battelle, LLC. All rights reserved.
1212
* $COPYRIGHT$
@@ -42,7 +42,11 @@
4242
extern opal_accelerator_base_module_t opal_accelerator;
4343
opal_common_ofi_module_t opal_common_ofi = {.prov_include = NULL,
4444
.prov_exclude = NULL,
45-
.output = -1};
45+
.output = -1,
46+
.fabric = NULL,
47+
.domain = NULL,
48+
.fabric_ref_count = 0,
49+
.domain_ref_count = 0};
4650
static const char default_prov_exclude_list[] = "shm,sockets,tcp,udp,rstream,usnic,net";
4751
static opal_mutex_t opal_common_ofi_mutex = OPAL_MUTEX_STATIC_INIT;
4852
static int opal_common_ofi_verbose_level = 0;
@@ -1257,3 +1261,156 @@ OPAL_DECLSPEC int opal_common_ofi_fi_getname(fid_t fid, void **addr, size_t *add
12571261
}
12581262
return ret;
12591263
}
1264+
1265+
/**
1266+
* Get or create fabric object
1267+
*
1268+
* Reuses existing fabric from fabric_attr->fabric if available,
1269+
* otherwise creates new fabric using fi_fabric().
1270+
*
1271+
* @param fabric_attr (IN) Fabric attributes
1272+
* @param fabric (OUT) Fabric object (new or existing)
1273+
*
1274+
* @return OPAL_SUCCESS or error code
1275+
*/
1276+
int opal_common_ofi_fi_fabric(struct fi_fabric_attr *fabric_attr,
1277+
struct fid_fabric **fabric)
1278+
{
1279+
int ret;
1280+
1281+
OPAL_THREAD_LOCK(&opal_common_ofi_mutex);
1282+
1283+
if (fabric_attr->fabric) {
1284+
*fabric = fabric_attr->fabric;
1285+
opal_common_ofi.fabric_ref_count++;
1286+
opal_output_verbose(1, opal_common_ofi.output, "Reusing existing fabric: %s",
1287+
fabric_attr->name);
1288+
} else {
1289+
ret = fi_fabric(fabric_attr, fabric, NULL);
1290+
if (0 != ret) {
1291+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1292+
return ret;
1293+
}
1294+
opal_common_ofi.fabric = *fabric;
1295+
opal_common_ofi.fabric_ref_count = 1;
1296+
}
1297+
1298+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1299+
return OPAL_SUCCESS;
1300+
}
1301+
1302+
/**
1303+
* Get or create domain object
1304+
*
1305+
* Reuses existing domain from info->domain_attr->domain if available,
1306+
* otherwise creates new domain using fi_domain().
1307+
*
1308+
* @param fabric (IN) Fabric object
1309+
* @param info (IN) Provider info
1310+
* @param domain (OUT) Domain object (new or existing)
1311+
*
1312+
* @return OPAL_SUCCESS or OPAL error code
1313+
*/
1314+
int opal_common_ofi_fi_domain(struct fid_fabric *fabric, struct fi_info *info,
1315+
struct fid_domain **domain)
1316+
{
1317+
int ret;
1318+
1319+
OPAL_THREAD_LOCK(&opal_common_ofi_mutex);
1320+
1321+
if (info->domain_attr->domain) {
1322+
*domain = info->domain_attr->domain;
1323+
opal_common_ofi.domain_ref_count++;
1324+
opal_output_verbose(1, opal_common_ofi.output, "Reusing existing domain: %s",
1325+
info->domain_attr->name);
1326+
} else {
1327+
ret = fi_domain(fabric, info, domain, NULL);
1328+
if (0 != ret) {
1329+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1330+
return ret;
1331+
}
1332+
opal_common_ofi.domain = *domain;
1333+
opal_common_ofi.domain_ref_count = 1;
1334+
}
1335+
1336+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1337+
return OPAL_SUCCESS;
1338+
}
1339+
1340+
/**
1341+
* Release fabric reference
1342+
*
1343+
* Decrements fabric reference count and closes fabric if count reaches zero.
1344+
*
1345+
* @param fabric (IN) Fabric object to release
1346+
*
1347+
* @return OPAL_SUCCESS or error code
1348+
*/
1349+
int opal_common_ofi_fabric_release(struct fid_fabric *fabric)
1350+
{
1351+
int ret = OPAL_SUCCESS;
1352+
1353+
OPAL_THREAD_LOCK(&opal_common_ofi_mutex);
1354+
1355+
if (fabric == opal_common_ofi.fabric && opal_common_ofi.fabric_ref_count > 0) {
1356+
opal_common_ofi.fabric_ref_count--;
1357+
if (opal_common_ofi.fabric_ref_count == 0) {
1358+
ret = fi_close(&fabric->fid);
1359+
if (0 != ret) {
1360+
opal_output_verbose(1, opal_common_ofi.output,
1361+
"%s:%d: fi_close failed for fabric: %s (%d)",
1362+
__FILE__, __LINE__, fi_strerror(-ret), ret);
1363+
}
1364+
opal_common_ofi.fabric = NULL;
1365+
}
1366+
} else {
1367+
ret = fi_close(&fabric->fid);
1368+
if (0 != ret) {
1369+
opal_output_verbose(1, opal_common_ofi.output,
1370+
"%s:%d: fi_close failed for fabric: %s (%d)",
1371+
__FILE__, __LINE__, fi_strerror(-ret), ret);
1372+
}
1373+
}
1374+
1375+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1376+
return ret;
1377+
}
1378+
1379+
/**
1380+
* Release domain reference
1381+
*
1382+
* Decrements domain reference count and closes domain if count reaches zero.
1383+
*
1384+
* @param domain (IN) Domain object to release
1385+
*
1386+
* @return OPAL_SUCCESS or error code
1387+
*/
1388+
int opal_common_ofi_domain_release(struct fid_domain *domain)
1389+
{
1390+
int ret = OPAL_SUCCESS;
1391+
1392+
OPAL_THREAD_LOCK(&opal_common_ofi_mutex);
1393+
1394+
if (domain == opal_common_ofi.domain && opal_common_ofi.domain_ref_count > 0) {
1395+
opal_common_ofi.domain_ref_count--;
1396+
if (opal_common_ofi.domain_ref_count == 0) {
1397+
ret = fi_close(&domain->fid);
1398+
if (0 != ret) {
1399+
opal_output_verbose(1, opal_common_ofi.output,
1400+
"%s:%d: fi_close failed for domain: %s (%d)",
1401+
__FILE__, __LINE__, fi_strerror(-ret), ret);
1402+
}
1403+
opal_common_ofi.domain = NULL;
1404+
}
1405+
} else {
1406+
ret = fi_close(&domain->fid);
1407+
if (0 != ret) {
1408+
opal_output_verbose(1, opal_common_ofi.output,
1409+
"%s:%d: fi_close failed for domain: %s (%d)",
1410+
__FILE__, __LINE__, fi_strerror(-ret), ret);
1411+
}
1412+
}
1413+
1414+
OPAL_THREAD_UNLOCK(&opal_common_ofi_mutex);
1415+
return ret;
1416+
}

0 commit comments

Comments
 (0)