Skip to content

Commit f8ba9bc

Browse files
authored
libfabric: Add CUDA memory registration support with fi_mr_regattr (#960)
1 parent 52aee5d commit f8ba9bc

File tree

5 files changed: 75 additions and 13 deletions

src/utils/libfabric/libfabric_common.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ getAvailableNetworkDevices() {
5050
hints->mode = FI_CONTEXT;
5151
hints->ep_attr->type = FI_EP_RDM;
5252

53-
int ret = fi_getinfo(FI_VERSION(1, 9), NULL, NULL, 0, hints, &info);
53+
int ret = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, hints, &info);
5454
if (ret) {
5555
NIXL_ERROR << "fi_getinfo failed " << fi_strerror(-ret);
5656
fi_freeinfo(hints);

src/utils/libfabric/libfabric_rail.cpp

Lines changed: 62 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,8 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
390390
provider_name(provider),
391391
blocking_cq_sread_supported(true),
392392
control_request_pool_(NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL, id),
393-
data_request_pool_(NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL, id) {
393+
data_request_pool_(NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL, id),
394+
provider_supports_hmem_(false) {
394395
// Initialize all pointers to nullptr
395396
info = nullptr;
396397
fabric = nullptr;
@@ -411,7 +412,7 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
411412
throw std::runtime_error("Failed to allocate fi_info for rail " + std::to_string(rail_id));
412413
}
413414
hints->caps = 0;
414-
hints->caps = FI_MSG | FI_RMA;
415+
hints->caps = FI_MSG | FI_RMA | FI_HMEM; // Try with FI_HMEM first
415416
hints->caps |= FI_LOCAL_COMM | FI_REMOTE_COMM;
416417
hints->mode = FI_CONTEXT;
417418
hints->ep_attr->type = FI_EP_RDM;
@@ -429,12 +430,32 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
429430
hints->domain_attr->name = strdup(device_name.c_str());
430431
hints->domain_attr->threading = FI_THREAD_SAFE;
431432
try {
432-
// Get fabric info for this specific device
433-
int ret = fi_getinfo(FI_VERSION(1, 9), NULL, NULL, 0, hints, &info);
434-
if (ret) {
435-
NIXL_ERROR << "fi_getinfo failed for rail " << rail_id << ": " << fi_strerror(-ret);
436-
throw std::runtime_error("fi_getinfo failed for rail " + std::to_string(rail_id));
433+
// Get fabric info for this specific device - first try with FI_HMEM
434+
int ret = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, hints, &info);
435+
436+
// If no provider found with FI_HMEM, retry without it
437+
if (ret || !info) {
438+
NIXL_INFO << "No provider found with FI_HMEM capability for rail " << rail_id
439+
<< ", retrying without FI_HMEM";
440+
441+
// Retry without FI_HMEM
442+
hints->caps = FI_MSG | FI_RMA;
443+
hints->caps |= FI_LOCAL_COMM | FI_REMOTE_COMM;
444+
445+
ret = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, hints, &info);
446+
if (ret) {
447+
NIXL_ERROR << "fi_getinfo failed for rail " << rail_id << ": " << fi_strerror(-ret);
448+
throw std::runtime_error("fi_getinfo failed for rail " + std::to_string(rail_id));
449+
}
450+
451+
provider_supports_hmem_ = false;
452+
NIXL_INFO << "Using provider without FI_HMEM support for rail " << rail_id;
453+
} else {
454+
// Provider found with FI_HMEM
455+
provider_supports_hmem_ = true;
456+
NIXL_INFO << "Using provider with FI_HMEM support for rail " << rail_id;
437457
}
458+
438459
// Create fabric for this rail
439460
ret = fi_fabric(info->fabric_attr, &fabric, NULL);
440461
if (ret) {
@@ -1254,6 +1275,8 @@ nixlLibfabricRail::postRead(void *local_buffer,
12541275
nixl_status_t
12551276
nixlLibfabricRail::registerMemory(void *buffer,
12561277
size_t length,
1278+
nixl_mem_t mem_type,
1279+
int gpu_id,
12571280
struct fid_mr **mr_out,
12581281
uint64_t *key_out) const {
12591282
if (!buffer || !mr_out || !key_out) {
@@ -1294,8 +1317,38 @@ nixlLibfabricRail::registerMemory(void *buffer,
12941317
<< " buffer=" << buffer << " length=" << length << " access_flags=0x" << std::hex
12951318
<< provider_access_flags << std::dec << " requested_key=" << requested_key;
12961319

1297-
int ret =
1298-
fi_mr_reg(domain, buffer, length, provider_access_flags, 0, requested_key, 0, &mr, NULL);
1320+
// Use fi_mr_regattr for enhanced memory registration control
1321+
struct fi_mr_attr mr_attr = {};
1322+
mr_attr.access = provider_access_flags;
1323+
mr_attr.offset = 0;
1324+
mr_attr.requested_key = requested_key;
1325+
mr_attr.context = nullptr;
1326+
mr_attr.auth_key_size = 0;
1327+
mr_attr.auth_key = nullptr;
1328+
1329+
// Set HMEM interface based on memory type and provider capability
1330+
if (mem_type == VRAM_SEG) {
1331+
if (provider_supports_hmem_) {
1332+
mr_attr.iface = FI_HMEM_CUDA;
1333+
mr_attr.device.cuda = gpu_id;
1334+
NIXL_DEBUG << "CUDA memory registration - iface: FI_HMEM_CUDA, device.cuda: " << gpu_id;
1335+
} else {
1336+
NIXL_WARN << "VRAM memory requested but provider does not support FI_HMEM - falling "
1337+
"back to system memory registration";
1338+
mr_attr.iface = FI_HMEM_SYSTEM;
1339+
}
1340+
} else {
1341+
mr_attr.iface = FI_HMEM_SYSTEM;
1342+
NIXL_DEBUG << "System memory registration - iface: FI_HMEM_SYSTEM";
1343+
}
1344+
1345+
struct iovec iov;
1346+
iov.iov_base = buffer;
1347+
iov.iov_len = length;
1348+
mr_attr.mr_iov = &iov;
1349+
mr_attr.iov_count = 1;
1350+
1351+
int ret = fi_mr_regattr(domain, &mr_attr, 0, &mr);
12991352
if (ret) {
13001353
NIXL_ERROR << "fi_mr_reg failed on rail " << rail_id << ": " << fi_strerror(-ret)
13011354
<< " (buffer=" << buffer << ", length=" << length

src/utils/libfabric/libfabric_rail.h

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,12 @@ class nixlLibfabricRail {
276276
// Memory registration methods
277277
/** Register memory buffer with libfabric */
278278
nixl_status_t
279-
registerMemory(void *buffer, size_t length, struct fid_mr **mr_out, uint64_t *key_out) const;
279+
registerMemory(void *buffer,
280+
size_t length,
281+
nixl_mem_t mem_type,
282+
int gpu_id,
283+
struct fid_mr **mr_out,
284+
uint64_t *key_out) const;
280285

281286
/** Deregister memory from libfabric */
282287
nixl_status_t
@@ -393,6 +398,9 @@ class nixlLibfabricRail {
393398
ControlRequestPool control_request_pool_;
394399
DataRequestPool data_request_pool_;
395400

401+
// Provider capability flags
402+
bool provider_supports_hmem_;
403+
396404

397405
nixl_status_t
398406
processCompletionQueueEntry(struct fi_cq_data_entry *comp) const;

src/utils/libfabric/libfabric_rail_manager.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -428,7 +428,8 @@ nixlLibfabricRailManager::registerMemory(void *buffer,
428428

429429
struct fid_mr *mr;
430430
uint64_t key;
431-
nixl_status_t status = data_rails_[rail_idx]->registerMemory(buffer, length, &mr, &key);
431+
nixl_status_t status =
432+
data_rails_[rail_idx]->registerMemory(buffer, length, mem_type, gpu_id, &mr, &key);
432433
if (status != NIXL_SUCCESS) {
433434
NIXL_ERROR << "Failed to register memory on rail " << rail_idx;
434435
// Cleanup already registered MRs

src/utils/libfabric/libfabric_topology.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -381,7 +381,7 @@ nixlLibfabricTopology::buildPcieToLibfabricMapping() {
381381
// This ensures consistency between device discovery and PCIe mapping
382382
hints->fabric_attr->prov_name = strdup(provider_name.c_str());
383383

384-
int ret = fi_getinfo(FI_VERSION(1, 9), NULL, NULL, 0, hints, &info);
384+
int ret = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, hints, &info);
385385
if (ret) {
386386
NIXL_ERROR << "fi_getinfo failed for PCIe mapping with provider " << provider_name << ": "
387387
<< fi_strerror(-ret);

0 commit comments

Comments
 (0)