@@ -390,7 +390,8 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
390390 provider_name(provider),
391391 blocking_cq_sread_supported(true ),
392392 control_request_pool_(NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL, id),
393- data_request_pool_(NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL, id) {
393+ data_request_pool_(NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL, id),
394+ provider_supports_hmem_(false ) {
394395 // Initialize all pointers to nullptr
395396 info = nullptr ;
396397 fabric = nullptr ;
@@ -411,7 +412,7 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
411412 throw std::runtime_error (" Failed to allocate fi_info for rail " + std::to_string (rail_id));
412413 }
413414 hints->caps = 0 ;
414- hints->caps = FI_MSG | FI_RMA;
415+ hints->caps = FI_MSG | FI_RMA | FI_HMEM; // Try with FI_HMEM first
415416 hints->caps |= FI_LOCAL_COMM | FI_REMOTE_COMM;
416417 hints->mode = FI_CONTEXT;
417418 hints->ep_attr ->type = FI_EP_RDM;
@@ -429,12 +430,32 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device,
429430 hints->domain_attr ->name = strdup (device_name.c_str ());
430431 hints->domain_attr ->threading = FI_THREAD_SAFE;
431432 try {
432- // Get fabric info for this specific device
433- int ret = fi_getinfo (FI_VERSION (1 , 9 ), NULL , NULL , 0 , hints, &info);
434- if (ret) {
435- NIXL_ERROR << " fi_getinfo failed for rail " << rail_id << " : " << fi_strerror (-ret);
436- throw std::runtime_error (" fi_getinfo failed for rail " + std::to_string (rail_id));
433+ // Get fabric info for this specific device - first try with FI_HMEM
434+ int ret = fi_getinfo (FI_VERSION (1 , 18 ), NULL , NULL , 0 , hints, &info);
435+
436+ // If no provider found with FI_HMEM, retry without it
437+ if (ret || !info) {
438+ NIXL_INFO << " No provider found with FI_HMEM capability for rail " << rail_id
439+ << " , retrying without FI_HMEM" ;
440+
441+ // Retry without FI_HMEM
442+ hints->caps = FI_MSG | FI_RMA;
443+ hints->caps |= FI_LOCAL_COMM | FI_REMOTE_COMM;
444+
445+ ret = fi_getinfo (FI_VERSION (1 , 18 ), NULL , NULL , 0 , hints, &info);
446+ if (ret) {
447+ NIXL_ERROR << " fi_getinfo failed for rail " << rail_id << " : " << fi_strerror (-ret);
448+ throw std::runtime_error (" fi_getinfo failed for rail " + std::to_string (rail_id));
449+ }
450+
451+ provider_supports_hmem_ = false ;
452+ NIXL_INFO << " Using provider without FI_HMEM support for rail " << rail_id;
453+ } else {
454+ // Provider found with FI_HMEM
455+ provider_supports_hmem_ = true ;
456+ NIXL_INFO << " Using provider with FI_HMEM support for rail " << rail_id;
437457 }
458+
438459 // Create fabric for this rail
439460 ret = fi_fabric (info->fabric_attr , &fabric, NULL );
440461 if (ret) {
@@ -1254,6 +1275,8 @@ nixlLibfabricRail::postRead(void *local_buffer,
12541275nixl_status_t
12551276nixlLibfabricRail::registerMemory (void *buffer,
12561277 size_t length,
1278+ nixl_mem_t mem_type,
1279+ int gpu_id,
12571280 struct fid_mr **mr_out,
12581281 uint64_t *key_out) const {
12591282 if (!buffer || !mr_out || !key_out) {
@@ -1294,8 +1317,38 @@ nixlLibfabricRail::registerMemory(void *buffer,
12941317 << " buffer=" << buffer << " length=" << length << " access_flags=0x" << std::hex
12951318 << provider_access_flags << std::dec << " requested_key=" << requested_key;
12961319
1297- int ret =
1298- fi_mr_reg (domain, buffer, length, provider_access_flags, 0 , requested_key, 0 , &mr, NULL );
1320+ // Use fi_mr_regattr for enhanced memory registration control
1321+ struct fi_mr_attr mr_attr = {};
1322+ mr_attr.access = provider_access_flags;
1323+ mr_attr.offset = 0 ;
1324+ mr_attr.requested_key = requested_key;
1325+ mr_attr.context = nullptr ;
1326+ mr_attr.auth_key_size = 0 ;
1327+ mr_attr.auth_key = nullptr ;
1328+
1329+ // Set HMEM interface based on memory type and provider capability
1330+ if (mem_type == VRAM_SEG) {
1331+ if (provider_supports_hmem_) {
1332+ mr_attr.iface = FI_HMEM_CUDA;
1333+ mr_attr.device .cuda = gpu_id;
1334+ NIXL_DEBUG << " CUDA memory registration - iface: FI_HMEM_CUDA, device.cuda: " << gpu_id;
1335+ } else {
1336+ NIXL_WARN << " VRAM memory requested but provider does not support FI_HMEM - falling "
1337+ " back to system memory registration" ;
1338+ mr_attr.iface = FI_HMEM_SYSTEM;
1339+ }
1340+ } else {
1341+ mr_attr.iface = FI_HMEM_SYSTEM;
1342+ NIXL_DEBUG << " System memory registration - iface: FI_HMEM_SYSTEM" ;
1343+ }
1344+
1345+ struct iovec iov;
1346+ iov.iov_base = buffer;
1347+ iov.iov_len = length;
1348+ mr_attr.mr_iov = &iov;
1349+ mr_attr.iov_count = 1 ;
1350+
1351+ int ret = fi_mr_regattr (domain, &mr_attr, 0 , &mr);
12991352 if (ret) {
13001353 NIXL_ERROR << " fi_mr_reg failed on rail " << rail_id << " : " << fi_strerror (-ret)
13011354 << " (buffer=" << buffer << " , length=" << length
0 commit comments