diff --git a/mooncake-transfer-engine/include/config.h b/mooncake-transfer-engine/include/config.h index e4d849c25..260d85bd1 100644 --- a/mooncake-transfer-engine/include/config.h +++ b/mooncake-transfer-engine/include/config.h @@ -48,6 +48,7 @@ struct GlobalConfig { bool use_ipv6 = false; size_t fragment_limit = 16384; bool enable_dest_device_affinity = false; + bool parallel_reg_mr = false; }; void loadGlobalConfig(GlobalConfig &config); diff --git a/mooncake-transfer-engine/src/config.cpp b/mooncake-transfer-engine/src/config.cpp index 61fcc835e..339951b71 100644 --- a/mooncake-transfer-engine/src/config.cpp +++ b/mooncake-transfer-engine/src/config.cpp @@ -258,6 +258,13 @@ void loadGlobalConfig(GlobalConfig &config) { if (std::getenv("MC_ENABLE_DEST_DEVICE_AFFINITY")) { config.enable_dest_device_affinity = true; } + + const char *enable_parallel_reg_mr = + std::getenv("MC_ENABLE_PARALLEL_REG_MR"); + if (enable_parallel_reg_mr) { + LOG(INFO) << "Enable parallel register memory region"; + config.parallel_reg_mr = true; + } } std::string mtuLengthToString(ibv_mtu mtu) { @@ -306,6 +313,8 @@ void dumpGlobalConfig() { LOG(INFO) << "max_wr = " << config.max_wr; LOG(INFO) << "max_inline = " << config.max_inline; LOG(INFO) << "mtu_length = " << mtuLengthToString(config.mtu_length); + LOG(INFO) << "parallel_reg_mr = " + << (config.parallel_reg_mr ? "true" : "false"); } GlobalConfig &globalConfig() { diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index e87d14db5..9aeff0b82 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -92,9 +92,39 @@ int RdmaTransport::registerLocalMemory(void *addr, size_t length, const static int access_rights = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + if (globalConfig().parallel_reg_mr) { + std::vector> registration_futures; + registration_futures.reserve(context_list_.size()); + + for (auto &context : context_list_) { + registration_futures.emplace_back(std::async( + std::launch::async, [&context, addr, length]() -> int { + return context->registerMemoryRegion(addr, length, + access_rights); + })); + } + + for (size_t i = 0; i < registration_futures.size(); ++i) { + int ret = registration_futures[i].get(); + if (ret) { + LOG(ERROR) << "Failed to register memory region with context " + << i; + return ret; + } + } + } else { + for (size_t i = 0; i < context_list_.size(); ++i) { + int ret = context_list_[i]->registerMemoryRegion(addr, length, + access_rights); + if (ret) { + LOG(ERROR) << "Failed to register memory region with context " + << i; + return ret; + } + } + } + for (auto &context : context_list_) { - int ret = context->registerMemoryRegion(addr, length, access_rights); - if (ret) return ret; buffer_desc.lkey.push_back(context->lkey(addr)); buffer_desc.rkey.push_back(context->rkey(addr)); } @@ -106,19 +136,15 @@ int RdmaTransport::registerLocalMemory(void *addr, size_t length, getMemoryLocation(addr, length); if (entries.empty()) return -1; buffer_desc.name = entries[0].location; - buffer_desc.addr = (uint64_t)addr; - buffer_desc.length = length; - int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); - if (rc) return rc; } else { buffer_desc.name = name; - buffer_desc.addr = (uint64_t)addr; - buffer_desc.length = length; - int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); - - if (rc) return rc; } + buffer_desc.addr = (uint64_t)addr; + buffer_desc.length = length; + int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (rc) return rc; + return 0; } @@ -126,7 +152,35 @@ int RdmaTransport::unregisterLocalMemory(void *addr, bool update_metadata) { int rc = metadata_->removeLocalMemoryBuffer(addr, update_metadata); if (rc) return rc; - for (auto &context : context_list_) context->unregisterMemoryRegion(addr); + if (globalConfig().parallel_reg_mr) { + std::vector> unregistration_futures; + unregistration_futures.reserve(context_list_.size()); + + for (auto &context : context_list_) { + unregistration_futures.emplace_back( + std::async(std::launch::async, [&context, addr]() -> int { + return context->unregisterMemoryRegion(addr); + })); + } + + for (size_t i = 0; i < unregistration_futures.size(); ++i) { + int ret = unregistration_futures[i].get(); + if (ret) { + LOG(ERROR) << "Failed to unregister memory region with context " + << i; + return ret; + } + } + } else { + for (size_t i = 0; i < context_list_.size(); ++i) { + int ret = context_list_[i]->unregisterMemoryRegion(addr); + if (ret) { + LOG(ERROR) << "Failed to unregister memory region with context " + << i; + return ret; + } + } + } return 0; }