diff --git a/CODEOWNERS b/CODEOWNERS index dc5ff775f..4de4df953 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -17,9 +17,8 @@ CODEOWNERS @ai-dynamo/Devops @ai-dynamo/nixl-maintainers /src/bindings/python @ovidiusm @mkhazraee @roiedanino /src/bindings/rust @roiedanino @gleon99 @mkhazraee -# UCX Plugins -/src/plugins/ucx* @brminich @yosefe @gleon99 -/src/utils/ucx @brminich @yosefe @gleon99 +# UCX Plugin +/src/plugins/ucx @brminich @yosefe @gleon99 # Storage Plugins /src/plugins/posix @w1ldptr @barneuman @etoledano @vvenkates27 diff --git a/src/utils/ucx/config.cpp b/src/plugins/ucx/config.cpp similarity index 66% rename from src/utils/ucx/config.cpp rename to src/plugins/ucx/config.cpp index 8c8397761..5f63cdad6 100644 --- a/src/utils/ucx/config.cpp +++ b/src/plugins/ucx/config.cpp @@ -24,21 +24,21 @@ namespace nixl::ucx { void -config::modify (std::string_view key, std::string_view value) const { - const char *env_val = std::getenv (absl::StrFormat ("UCX_%s", key.data()).c_str()); +config::modify(std::string_view key, std::string_view value) const { + const char *env_val = std::getenv(absl::StrFormat("UCX_%s", key.data()).c_str()); if (env_val) { NIXL_DEBUG << "UCX env var has already been set: " << key << "=" << env_val; } else { - modifyAlways (key, value); + modifyAlways(key, value); } } void -config::modifyAlways (std::string_view key, std::string_view value) const { - const auto status = ucp_config_modify (config_.get(), key.data(), value.data()); +config::modifyAlways(std::string_view key, std::string_view value) const { + const auto status = ucp_config_modify(config_.get(), key.data(), value.data()); if (status != UCS_OK) { NIXL_WARN << "Failed to modify UCX config: " << key << "=" << value << ": " - << ucs_status_string (status); + << ucs_status_string(status); } else { NIXL_DEBUG << "Modified UCX config: " << key << "=" << value; } @@ -47,12 +47,10 @@ config::modifyAlways (std::string_view key, std::string_view value) const { ucp_config_t * config::readUcpConfig() { ucp_config_t *config = nullptr; - const auto status = ucp_config_read (NULL, NULL, &config); + const auto status = ucp_config_read(nullptr, nullptr, &config); if (status != UCS_OK) { - const auto err_str = - std::string ("Failed to create UCX config: ") + ucs_status_string (status); - NIXL_ERROR << err_str; - throw std::runtime_error (err_str); + throw std::runtime_error("Failed to create UCX config: " + + std::string(ucs_status_string(status))); } return config; } diff --git a/src/utils/ucx/config.h b/src/plugins/ucx/config.h similarity index 83% rename from src/utils/ucx/config.h rename to src/plugins/ucx/config.h index 0cd2a3ea9..2c8b88ca9 100644 --- a/src/utils/ucx/config.h +++ b/src/plugins/ucx/config.h @@ -37,18 +37,18 @@ class config { // Modify the config if it is not already set via environment variable void - modify (std::string_view key, std::string_view value) const; + modify(std::string_view key, std::string_view value) const; // Modify the config always void - modifyAlways (std::string_view key, std::string_view value) const; + modifyAlways(std::string_view key, std::string_view value) const; private: [[nodiscard]] static ucp_config_t * readUcpConfig(); - const std::unique_ptr config_{readUcpConfig(), - &ucp_config_release}; + const std::unique_ptr config_{readUcpConfig(), + &ucp_config_release}; }; } // namespace nixl::ucx diff --git a/src/utils/ucx/gpu_xfer_req_h.cpp b/src/plugins/ucx/gpu_xfer_req_h.cpp similarity index 100% rename from src/utils/ucx/gpu_xfer_req_h.cpp rename to src/plugins/ucx/gpu_xfer_req_h.cpp diff --git a/src/utils/ucx/gpu_xfer_req_h.h b/src/plugins/ucx/gpu_xfer_req_h.h similarity index 100% rename from src/utils/ucx/gpu_xfer_req_h.h rename to src/plugins/ucx/gpu_xfer_req_h.h diff --git a/src/plugins/ucx/meson.build b/src/plugins/ucx/meson.build index aebfa92f9..c53f049fd 100644 --- a/src/plugins/ucx/meson.build +++ b/src/plugins/ucx/meson.build @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -ucx_utils_dep = declare_dependency(link_with: ucx_utils_lib, include_directories: utils_inc_dirs ) asio_dep = [dependency('asio', required: true)] compile_flags = [] @@ -21,19 +20,39 @@ if cuda_dep.found() compile_flags = [ '-DHAVE_CUDA' ] endif +ucx_backend_includes = include_directories('.') + +ucx_backend_sources = ['config.cpp', + 'gpu_xfer_req_h.cpp', + 'rkey.cpp', + 'ucx_backend.cpp', + 'ucx_plugin.cpp', + 'ucx_utils.cpp'] + +ucx_backend_dependencies = [asio_dep, + cuda_dep, + nixl_common_dep, + nixl_infra, + serdes_interface, + thread_dep, + ucx_dep] + +ucx_backend_include_directories = [nixl_inc_dirs, + utils_inc_dirs] + if 'UCX' in static_plugins ucx_backend_lib = static_library('UCX', - 'ucx_backend.cpp', 'ucx_backend.h', 'ucx_plugin.cpp', - dependencies: [nixl_infra, ucx_utils_dep, serdes_interface, cuda_dep, ucx_dep, thread_dep, nixl_common_dep, asio_dep], - include_directories: nixl_inc_dirs, + ucx_backend_sources, + dependencies: ucx_backend_dependencies, + include_directories: ucx_backend_include_directories, install: false, cpp_args : compile_flags, name_prefix: 'libplugin_') # Custom prefix for plugin libraries else ucx_backend_lib = shared_library('UCX', - 'ucx_backend.cpp', 'ucx_backend.h', 'ucx_plugin.cpp', - dependencies: [nixl_infra, ucx_utils_dep, serdes_interface, cuda_dep, ucx_dep, thread_dep, nixl_common_dep, asio_dep], - include_directories: nixl_inc_dirs, + ucx_backend_sources, + dependencies: ucx_backend_dependencies, + include_directories: ucx_backend_include_directories, install: true, cpp_args : compile_flags + ['-fPIC'], name_prefix: 'libplugin_', # Custom prefix for plugin libraries @@ -48,4 +67,4 @@ else endif endif -ucx_backend_interface = declare_dependency(link_with: ucx_backend_lib) +ucx_backend_interface = declare_dependency(link_with: ucx_backend_lib, include_directories: ucx_backend_includes) diff --git a/src/utils/ucx/rkey.cpp b/src/plugins/ucx/rkey.cpp similarity index 100% rename from src/utils/ucx/rkey.cpp rename to src/plugins/ucx/rkey.cpp diff --git a/src/utils/ucx/rkey.h b/src/plugins/ucx/rkey.h similarity index 100% rename from src/utils/ucx/rkey.h rename to src/plugins/ucx/rkey.h diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp index 9be55fb28..474a8b155 100644 --- a/src/plugins/ucx/ucx_backend.cpp +++ b/src/plugins/ucx/ucx_backend.cpp @@ -19,7 +19,7 @@ #include "common/nixl_log.h" #include "serdes/serdes.h" #include "common/nixl_log.h" -#include "ucx/gpu_xfer_req_h.h" +#include "gpu_xfer_req_h.h" #include #include diff --git a/src/plugins/ucx/ucx_backend.h b/src/plugins/ucx/ucx_backend.h index bf9e6546d..547dec1d4 100644 --- a/src/plugins/ucx/ucx_backend.h +++ b/src/plugins/ucx/ucx_backend.h @@ -35,8 +35,8 @@ // Local includes #include "common/nixl_time.h" -#include "ucx/rkey.h" -#include "ucx/ucx_utils.h" +#include "rkey.h" +#include "ucx_utils.h" enum ucx_cb_op_t { NOTIF_STR }; diff --git a/src/utils/ucx/ucx_utils.cpp b/src/plugins/ucx/ucx_utils.cpp similarity index 75% rename from src/utils/ucx/ucx_utils.cpp rename to src/plugins/ucx/ucx_utils.cpp index d337e7247..a9000cc60 100644 --- a/src/utils/ucx/ucx_utils.cpp +++ b/src/plugins/ucx/ucx_utils.cpp @@ -48,13 +48,13 @@ get_ucx_backend_common_options() { return params; } -nixl_status_t ucx_status_to_nixl(ucs_status_t status) -{ +nixl_status_t +ucx_status_to_nixl(ucs_status_t status) { if (status == UCS_OK) { return NIXL_SUCCESS; } - switch(status) { + switch (status) { case UCS_INPROGRESS: case UCS_ERR_BUSY: return NIXL_IN_PROG; @@ -112,21 +112,21 @@ ucx_err_mode_from_string(std::string_view s) { static void err_cb_wrapper(void *arg, ucp_ep_h ucp_ep, ucs_status_t status) { - nixlUcxEp *ep = reinterpret_cast(arg); + auto ep = reinterpret_cast(arg); ep->err_cb(ucp_ep, status); } -void nixlUcxEp::err_cb(ucp_ep_h ucp_ep, ucs_status_t status) -{ +void +nixlUcxEp::err_cb(ucp_ep_h ucp_ep, ucs_status_t status) { ucs_status_ptr_t request; NIXL_DEBUG << "ep " << eph << ": state " << state - << ", UCX error handling callback was invoked with status " - << status << " (" << ucs_status_string(status) << ")"; + << ", UCX error handling callback was invoked with status " << status << " (" + << ucs_status_string(status) << ")"; NIXL_ASSERT(eph == ucp_ep); - switch(state) { + switch (state) { case NIXL_UCX_EP_STATE_NULL: case NIXL_UCX_EP_STATE_FAILED: // The error was already handled, nothing to do @@ -145,23 +145,19 @@ void nixlUcxEp::err_cb(ucp_ep_h ucp_ep, ucs_status_t status) std::terminate(); } -void nixlUcxEp::setState(nixl_ucx_ep_state_t new_state) -{ +void +nixlUcxEp::setState(nixl_ucx_ep_state_t new_state) { NIXL_ASSERT(new_state != state); NIXL_DEBUG << "ep " << eph << ": state " << state << " -> " << new_state; state = new_state; } nixl_status_t -nixlUcxEp::closeImpl(ucp_ep_close_flags_t flags) -{ - ucs_status_ptr_t request = nullptr; - ucp_request_param_t req_param = { - .op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, - .flags = flags - }; +nixlUcxEp::closeImpl(ucp_ep_close_flags_t flags) { + ucs_status_ptr_t request = nullptr; + ucp_request_param_t req_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = flags}; - switch(state) { + switch (state) { case NIXL_UCX_EP_STATE_NULL: case NIXL_UCX_EP_STATE_DISCONNECTED: // The EP has not been connected, or already disconnected. @@ -192,19 +188,16 @@ nixlUcxEp::closeImpl(ucp_ep_close_flags_t flags) std::terminate(); } -nixlUcxEp::nixlUcxEp(ucp_worker_h worker, void* addr, - ucp_err_handling_mode_t err_handling_mode) -{ +nixlUcxEp::nixlUcxEp(ucp_worker_h worker, void *addr, ucp_err_handling_mode_t err_handling_mode) { ucp_ep_params_t ep_params; nixl_status_t status; - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | - UCP_EP_PARAM_FIELD_ERR_HANDLER | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; - ep_params.err_mode = err_handling_mode; - ep_params.err_handler.cb = err_cb_wrapper; - ep_params.err_handler.arg = reinterpret_cast(this); - ep_params.address = reinterpret_cast(addr); + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLER | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + ep_params.err_mode = err_handling_mode; + ep_params.err_handler.cb = err_cb_wrapper; + ep_params.err_handler.arg = reinterpret_cast(this); + ep_params.address = reinterpret_cast(addr); status = ucx_status_to_nixl(ucp_ep_create(worker, &ep_params, &eph)); if (status == NIXL_SUCCESS) @@ -213,19 +206,17 @@ nixlUcxEp::nixlUcxEp(ucp_worker_h worker, void* addr, throw std::runtime_error("failed to create ep"); } - nixlUcxEp::~nixlUcxEp() - { - nixl_status_t status = disconnect_nb(); - if (status) - NIXL_ERROR << "Failed to disconnect ep with status " << status; - } +nixlUcxEp::~nixlUcxEp() { + nixl_status_t status = disconnect_nb(); + if (status) NIXL_ERROR << "Failed to disconnect ep with status " << status; +} /* =========================================== * EP management * =========================================== */ -nixl_status_t nixlUcxEp::disconnect_nb() -{ +nixl_status_t +nixlUcxEp::disconnect_nb() { nixl_status_t status = closeImpl(ucp_ep_close_flags_t(0)); // At step of disconnect we can ignore the remote disconnect error. @@ -263,7 +254,7 @@ nixlUcxEp::sendAm(unsigned msg_id, ucp_request_param_t param = {0}; param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; - param.flags = flags; + param.flags = flags; nixl_ucx_am_cb_ctx_ptr_t ctx; if (deleter) { @@ -304,14 +295,13 @@ nixlUcxEp::read(uint64_t raddr, } ucp_request_param_t param = { - .op_attr_mask = UCP_OP_ATTR_FIELD_MEMH | - UCP_OP_ATTR_FLAG_MULTI_SEND, - .memh = mem.memh, + .op_attr_mask = UCP_OP_ATTR_FIELD_MEMH | UCP_OP_ATTR_FLAG_MULTI_SEND, + .memh = mem.memh, }; ucs_status_ptr_t request = ucp_get_nbx(eph, laddr, size, raddr, rkey.get(), ¶m); if (UCS_PTR_IS_PTR(request)) { - req = (void*)request; + req = static_cast(request); return NIXL_IN_PROG; } @@ -331,27 +321,26 @@ nixlUcxEp::write(void *laddr, } ucp_request_param_t param = { - .op_attr_mask = UCP_OP_ATTR_FIELD_MEMH | - UCP_OP_ATTR_FLAG_MULTI_SEND, - .memh = mem.memh, + .op_attr_mask = UCP_OP_ATTR_FIELD_MEMH | UCP_OP_ATTR_FLAG_MULTI_SEND, + .memh = mem.memh, }; ucs_status_ptr_t request = ucp_put_nbx(eph, laddr, size, raddr, rkey.get(), ¶m); if (UCS_PTR_IS_PTR(request)) { - req = (void*)request; + req = static_cast(request); return NIXL_IN_PROG; } return ucx_status_to_nixl(UCS_PTR_STATUS(request)); } -nixl_status_t nixlUcxEp::estimateCost(size_t size, - std::chrono::microseconds &duration, - std::chrono::microseconds &err_margin, - nixl_cost_t &method) -{ +nixl_status_t +nixlUcxEp::estimateCost(size_t size, + std::chrono::microseconds &duration, + std::chrono::microseconds &err_margin, + nixl_cost_t &method) { ucp_ep_evaluate_perf_param_t params = { - .field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE, + .field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE, .message_size = size, }; @@ -365,15 +354,16 @@ nixl_status_t nixlUcxEp::estimateCost(size_t size, return NIXL_ERR_BACKEND; } - duration = std::chrono::duration_cast(std::chrono::duration(cost_result.estimated_time)); + duration = std::chrono::duration_cast( + std::chrono::duration(cost_result.estimated_time)); method = nixl_cost_t::ANALYTICAL_BACKEND; // Currently, we do not have a way to estimate the error margin err_margin = std::chrono::microseconds(0); return NIXL_SUCCESS; } -nixl_status_t nixlUcxEp::flushEp(nixlUcxReq &req) -{ +nixl_status_t +nixlUcxEp::flushEp(nixlUcxReq &req) { ucp_request_param_t param; ucs_status_ptr_t request; @@ -381,20 +371,20 @@ nixl_status_t nixlUcxEp::flushEp(nixlUcxReq &req) request = ucp_ep_flush_nbx(eph, ¶m); if (UCS_PTR_IS_PTR(request)) { - req = (void*)request; + req = static_cast(request); return NIXL_IN_PROG; } return ucx_status_to_nixl(UCS_PTR_STATUS(request)); } -bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept -{ +bool +nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept { ucp_lib_attr_t attr; attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL; ucp_lib_query(&attr); - switch(mt_type) { + switch (mt_type) { case nixl_ucx_mt_t::SINGLE: return attr.max_thread_level >= UCS_THREAD_MODE_SERIALIZED; case nixl_ucx_mt_t::CTX: @@ -417,7 +407,8 @@ nixlUcxContext::nixlUcxContext(std::vector devs, // state is properly protected. Progress thread creates internal concurrency in UCX backend // irrespective of nixlAgent synchronization model. mt_type = (sync_mode == nixl_thread_sync_t::NIXL_THREAD_SYNC_RW || prog_thread) ? - nixl_ucx_mt_t::WORKER : nixl_ucx_mt_t::SINGLE; + nixl_ucx_mt_t::WORKER : + nixl_ucx_mt_t::SINGLE; ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED; ucp_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64 | UCP_FEATURE_AM; @@ -425,8 +416,7 @@ nixlUcxContext::nixlUcxContext(std::vector devs, ucp_params.features |= UCP_FEATURE_DEVICE; #endif - if (prog_thread) - ucp_params.features |= UCP_FEATURE_WAKEUP; + if (prog_thread) ucp_params.features |= UCP_FEATURE_WAKEUP; ucp_params.mt_workers_shared = num_workers > 1 ? 1 : 0; if (req_size) { @@ -444,74 +434,70 @@ nixlUcxContext::nixlUcxContext(std::vector devs, devs_str += dev + ":1,"; } devs_str.pop_back(); // to remove odd comma after the last device - config.modifyAlways ("NET_DEVICES", devs_str.c_str()); + config.modifyAlways("NET_DEVICES", devs_str.c_str()); } unsigned major_version, minor_version, release_number; ucp_get_version(&major_version, &minor_version, &release_number); - config.modify ("ADDRESS_VERSION", "v2"); - config.modify ("RNDV_THRESH", "inf"); + config.modify("ADDRESS_VERSION", "v2"); + config.modify("RNDV_THRESH", "inf"); unsigned ucp_version = UCP_VERSION(major_version, minor_version); if (ucp_version >= UCP_VERSION(1, 19)) { - config.modify ("MAX_COMPONENT_MDS", "32"); + config.modify("MAX_COMPONENT_MDS", "32"); } if (ucp_version >= UCP_VERSION(1, 20)) { - config.modify ("MAX_RMA_RAILS", "4"); + config.modify("MAX_RMA_RAILS", "4"); } else { - config.modify ("MAX_RMA_RAILS", "2"); + config.modify("MAX_RMA_RAILS", "2"); } - const auto status = ucp_init (&ucp_params, config.getUcpConfig(), &ctx); + const auto status = ucp_init(&ucp_params, config.getUcpConfig(), &ctx); if (status != UCS_OK) { - throw std::runtime_error ("Failed to create UCX context: " + - std::string (ucs_status_string (status))); + throw std::runtime_error("Failed to create UCX context: " + + std::string(ucs_status_string(status))); } } -nixlUcxContext::~nixlUcxContext() -{ +nixlUcxContext::~nixlUcxContext() { ucp_cleanup(ctx); } -namespace -{ - [[nodiscard]] ucs_thread_mode_t toUcsThreadModeChecked(const nixl_ucx_mt_t t) - { - switch(t) { - case nixl_ucx_mt_t::CTX: - return UCS_THREAD_MODE_SINGLE; - case nixl_ucx_mt_t::SINGLE: - return UCS_THREAD_MODE_SERIALIZED; - case nixl_ucx_mt_t::WORKER: - return UCS_THREAD_MODE_MULTI; - } - NIXL_FATAL << "Invalid UCX worker type: " << static_cast>(t); - std::terminate(); - } - - struct nixlUcpWorkerParams - : ucp_worker_params_t - { - explicit nixlUcpWorkerParams(const nixl_ucx_mt_t t) - { - field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - thread_mode = toUcsThreadModeChecked(t); - } - }; - - static_assert(sizeof(nixlUcpWorkerParams) == sizeof(ucp_worker_params_t)); - -} // namespace +namespace { +[[nodiscard]] ucs_thread_mode_t +toUcsThreadModeChecked(const nixl_ucx_mt_t t) { + switch (t) { + case nixl_ucx_mt_t::CTX: + return UCS_THREAD_MODE_SINGLE; + case nixl_ucx_mt_t::SINGLE: + return UCS_THREAD_MODE_SERIALIZED; + case nixl_ucx_mt_t::WORKER: + return UCS_THREAD_MODE_MULTI; + } + NIXL_FATAL << "Invalid UCX worker type: " + << static_cast>(t); + std::terminate(); +} + +struct nixlUcpWorkerParams : ucp_worker_params_t { + explicit nixlUcpWorkerParams(const nixl_ucx_mt_t t) { + field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + thread_mode = toUcsThreadModeChecked(t); + } +}; + +static_assert(sizeof(nixlUcpWorkerParams) == sizeof(ucp_worker_params_t)); + +} // namespace ucp_worker * nixlUcxWorker::createUcpWorker(const nixlUcxContext &ctx) { - ucp_worker* worker = nullptr; + ucp_worker *worker = nullptr; const nixlUcpWorkerParams params(ctx.mt_type); const ucs_status_t status = ucp_worker_create(ctx.ctx, ¶ms, &worker); - if(status != UCS_OK) { + if (status != UCS_OK) { throw std::runtime_error(std::string("Failed to create UCX worker: ") + ucs_status_string(status)); } @@ -523,8 +509,8 @@ nixlUcxWorker::nixlUcxWorker(const nixlUcxContext &ctx, ucp_err_handling_mode_t : worker(createUcpWorker(ctx), &ucp_worker_destroy), err_handling_mode_(err_handling_mode) {} -std::string nixlUcxWorker::epAddr() -{ +std::string +nixlUcxWorker::epAddr() { ucp_worker_attr_t wattr; wattr.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS; @@ -539,11 +525,12 @@ std::string nixlUcxWorker::epAddr() return result; } -absl::StatusOr> nixlUcxWorker::connect(void* addr, std::size_t size) -{ +absl::StatusOr> +nixlUcxWorker::connect(void *addr, std::size_t size) { try { return std::make_unique(worker.get(), addr, err_handling_mode_); - } catch (const std::exception &e) { + } + catch (const std::exception &e) { return absl::UnavailableError(e.what()); } } @@ -553,23 +540,21 @@ absl::StatusOr> nixlUcxWorker::connect(void* addr, st * =========================================== */ -int nixlUcxContext::memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t nixl_mem_type) -{ - //mem.uw = this; +int +nixlUcxContext::memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t nixl_mem_type) { mem.base = addr; mem.size = size; ucp_mem_map_params_t mem_params = { - .field_mask = UCP_MEM_MAP_PARAM_FIELD_FLAGS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH | - UCP_MEM_MAP_PARAM_FIELD_ADDRESS, + .field_mask = UCP_MEM_MAP_PARAM_FIELD_FLAGS | UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_ADDRESS, .address = mem.base, - .length = mem.size, + .length = mem.size, }; ucs_status_t status = ucp_mem_map(ctx, &mem_params, &mem.memh); if (status != UCS_OK) { - /* TODOL: MSW_NET_ERROR(priv->net, "failed to ucp_mem_map (%s)\n", ucs_status_string(status)); */ + NIXL_ERROR << "Failed to ucp_mem_map: " << ucs_status_string(status); return -1; } @@ -578,8 +563,7 @@ int nixlUcxContext::memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t attr.field_mask = UCP_MEM_ATTR_FIELD_MEM_TYPE; status = ucp_mem_query(mem.memh, &attr); if (status != UCS_OK) { - NIXL_ERROR << absl::StrFormat("Failed to ucp_mem_query: %s", - ucs_status_string(status)); + NIXL_ERROR << "Failed to ucp_mem_query: " << ucs_status_string(status); ucp_mem_unmap(ctx, mem.memh); return -1; } @@ -593,14 +577,14 @@ int nixlUcxContext::memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t return 0; } -std::string nixlUcxContext::packRkey(nixlUcxMem &mem) -{ - void* rkey_buf; +std::string +nixlUcxContext::packRkey(nixlUcxMem &mem) { + void *rkey_buf; std::size_t size; const ucs_status_t status = ucp_rkey_pack(ctx, mem.memh, &rkey_buf, &size); if (status != UCS_OK) { - /* TODO: MSW_NET_ERROR(priv->net, "failed to ucp_rkey_pack (%s)\n", ucs_status_string(status)); */ + NIXL_ERROR << "Failed to ucp_rkey_pack: " << ucs_status_string(status); return {}; } const std::string result = nixlSerDes::_bytesToString(rkey_buf, size); @@ -608,8 +592,8 @@ std::string nixlUcxContext::packRkey(nixlUcxMem &mem) return result; } -void nixlUcxContext::memDereg(nixlUcxMem &mem) -{ +void +nixlUcxContext::memDereg(nixlUcxMem &mem) { ucp_mem_unmap(ctx, mem.memh); } @@ -620,8 +604,6 @@ constexpr std::string_view ucxGpuDeviceApiUnsupported{ } #endif - - size_t nixlUcxContext::getGpuSignalSize() const { #ifdef HAVE_UCX_GPU_DEVICE_API @@ -645,13 +627,12 @@ nixlUcxContext::getGpuSignalSize() const { * Active message handling * =========================================== */ -int nixlUcxWorker::regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, void* arg) -{ +int +nixlUcxWorker::regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, void *arg) { ucp_am_handler_param_t params = {0}; - params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | - UCP_AM_HANDLER_PARAM_FIELD_CB | - UCP_AM_HANDLER_PARAM_FIELD_ARG; + params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG; params.id = msg_id; params.cb = cb; @@ -659,8 +640,8 @@ int nixlUcxWorker::regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, voi const ucs_status_t status = ucp_worker_set_am_recv_handler(worker.get(), ¶ms); - if(status != UCS_OK) { - //TODO: error handling + if (status != UCS_OK) { + // TODO: error handling return -1; } return 0; @@ -670,27 +651,27 @@ int nixlUcxWorker::regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, voi * Data transfer * =========================================== */ -int nixlUcxWorker::progress() -{ - return ucp_worker_progress(worker.get()); +int +nixlUcxWorker::progress() { + return ucp_worker_progress(worker.get()); } -nixl_status_t nixlUcxWorker::test(nixlUcxReq req) -{ - if(req == nullptr) { +nixl_status_t +nixlUcxWorker::test(nixlUcxReq req) { + if (req == nullptr) { return NIXL_SUCCESS; } ucp_worker_progress(worker.get()); return ucx_status_to_nixl(ucp_request_check_status(req)); } -void nixlUcxWorker::reqRelease(nixlUcxReq req) -{ +void +nixlUcxWorker::reqRelease(nixlUcxReq req) { ucp_request_free(req); } -void nixlUcxWorker::reqCancel(nixlUcxReq req) -{ +void +nixlUcxWorker::reqCancel(nixlUcxReq req) { ucp_request_cancel(worker.get(), req); } diff --git a/src/utils/ucx/ucx_utils.h b/src/plugins/ucx/ucx_utils.h similarity index 71% rename from src/utils/ucx/ucx_utils.h rename to src/plugins/ucx/ucx_utils.h index 44c06ca44..e905a29e0 100644 --- a/src/utils/ucx/ucx_utils.h +++ b/src/plugins/ucx/ucx_utils.h @@ -20,8 +20,7 @@ #include #include -extern "C" -{ +extern "C" { #include } @@ -31,32 +30,27 @@ extern "C" #include "absl/strings/numbers.h" -enum class nixl_ucx_mt_t { - SINGLE, - CTX, - WORKER -}; +enum class nixl_ucx_mt_t { SINGLE, CTX, WORKER }; constexpr std::string_view nixl_ucx_err_handling_param_name = "ucx_error_handling_mode"; template -[[nodiscard]] constexpr auto enumToInteger(const Enum e) noexcept -{ +[[nodiscard]] constexpr auto +enumToInteger(const Enum e) noexcept { static_assert(std::is_enum_v); return std::underlying_type_t(e); } -[[nodiscard]] std::string_view constexpr to_string_view(const nixl_ucx_mt_t t) noexcept -{ - switch(t) { - case nixl_ucx_mt_t::SINGLE: - return "SINGLE"; - case nixl_ucx_mt_t::CTX: - return "CTX"; - case nixl_ucx_mt_t::WORKER: - return "WORKER"; +[[nodiscard]] std::string_view constexpr to_string_view(const nixl_ucx_mt_t t) noexcept { + switch (t) { + case nixl_ucx_mt_t::SINGLE: + return "SINGLE"; + case nixl_ucx_mt_t::CTX: + return "CTX"; + case nixl_ucx_mt_t::WORKER: + return "WORKER"; } - return "INVALID"; // It is not a to_string function's job to validate. + return "INVALID"; // It is not a to_string function's job to validate. } template @@ -77,7 +71,7 @@ nixl_b_params_get(const nixl_b_params_t *custom_params, const std::string &key, } } -using nixlUcxReq = void*; +using nixlUcxReq = void *; namespace nixl::ucx { class rkey; @@ -91,23 +85,29 @@ class nixlUcxEp { NIXL_UCX_EP_STATE_FAILED, NIXL_UCX_EP_STATE_DISCONNECTED }; + private: - ucp_ep_h eph{nullptr}; + ucp_ep_h eph{nullptr}; nixl_ucx_ep_state_t state{NIXL_UCX_EP_STATE_NULL}; - void setState(nixl_ucx_ep_state_t new_state); - nixl_status_t closeImpl(ucp_ep_close_flags_t flags); + void + setState(nixl_ucx_ep_state_t new_state); + nixl_status_t + closeImpl(ucp_ep_close_flags_t flags); /* Connection */ - nixl_status_t disconnect_nb(); + nixl_status_t + disconnect_nb(); static void sendAmCallback(void *request, ucs_status_t status, void *user_data); public: - void err_cb(ucp_ep_h ucp_ep, ucs_status_t status); + void + err_cb(ucp_ep_h ucp_ep, ucs_status_t status); - nixl_status_t checkTxState() const { + nixl_status_t + checkTxState() const { switch (state) { case NIXL_UCX_EP_STATE_CONNECTED: return NIXL_SUCCESS; @@ -120,10 +120,11 @@ class nixlUcxEp { } } - nixlUcxEp(ucp_worker_h worker, void* addr, ucp_err_handling_mode_t err_handling_mode); + nixlUcxEp(ucp_worker_h worker, void *addr, ucp_err_handling_mode_t err_handling_mode); ~nixlUcxEp(); - nixlUcxEp(const nixlUcxEp&) = delete; - nixlUcxEp& operator=(const nixlUcxEp&) = delete; + nixlUcxEp(const nixlUcxEp &) = delete; + nixlUcxEp & + operator=(const nixlUcxEp &) = delete; using am_deleter_t = std::function; @@ -153,11 +154,13 @@ class nixlUcxEp { const nixl::ucx::rkey &rkey, size_t size, nixlUcxReq &req); - nixl_status_t estimateCost(size_t size, - std::chrono::microseconds &duration, - std::chrono::microseconds &err_margin, - nixl_cost_t &method); - nixl_status_t flushEp(nixlUcxReq &req); + nixl_status_t + estimateCost(size_t size, + std::chrono::microseconds &duration, + std::chrono::microseconds &err_margin, + nixl_cost_t &method); + nixl_status_t + flushEp(nixlUcxReq &req); [[nodiscard]] ucp_ep_h getEp() const noexcept { @@ -170,6 +173,7 @@ class nixlUcxMem { void *base; size_t size; ucp_mem_h memh; + public: [[nodiscard]] ucp_mem_h getMemh() const noexcept { @@ -206,9 +210,12 @@ class nixlUcxContext { ~nixlUcxContext(); /* Memory management */ - int memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t nixl_mem_type); - [[nodiscard]] std::string packRkey(nixlUcxMem &mem); - void memDereg(nixlUcxMem &mem); + int + memReg(void *addr, size_t size, nixlUcxMem &mem, nixl_mem_t nixl_mem_type); + [[nodiscard]] std::string + packRkey(nixlUcxMem &mem); + void + memDereg(nixlUcxMem &mem); /* GPU signal management */ [[nodiscard]] size_t @@ -217,7 +224,8 @@ class nixlUcxContext { friend class nixlUcxWorker; }; -[[nodiscard]] bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t) noexcept; +[[nodiscard]] bool +nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t) noexcept; class nixlUcxWorker { public: @@ -225,24 +233,33 @@ class nixlUcxWorker { const nixlUcxContext &, ucp_err_handling_mode_t ucp_err_handling_mode = UCP_ERR_HANDLING_MODE_NONE); - nixlUcxWorker( nixlUcxWorker&& ) = delete; - nixlUcxWorker( const nixlUcxWorker& ) = delete; - void operator=( nixlUcxWorker&& ) = delete; - void operator=( const nixlUcxWorker& ) = delete; + nixlUcxWorker(nixlUcxWorker &&) = delete; + nixlUcxWorker(const nixlUcxWorker &) = delete; + void + operator=(nixlUcxWorker &&) = delete; + void + operator=(const nixlUcxWorker &) = delete; /* Connection */ - [[nodiscard]] std::string epAddr(); - absl::StatusOr> connect(void* addr, size_t size); + [[nodiscard]] std::string + epAddr(); + absl::StatusOr> + connect(void *addr, size_t size); /* Active message handling */ - int regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, void* arg); + int + regAmCallback(unsigned msg_id, ucp_am_recv_callback_t cb, void *arg); /* Data access */ - int progress(); - [[nodiscard]] nixl_status_t test(nixlUcxReq req); + int + progress(); + [[nodiscard]] nixl_status_t + test(nixlUcxReq req); - void reqRelease(nixlUcxReq req); - void reqCancel(nixlUcxReq req); + void + reqRelease(nixlUcxReq req); + void + reqCancel(nixlUcxReq req); [[nodiscard]] nixl_status_t arm() const noexcept; @@ -265,7 +282,8 @@ class nixlUcxWorker { [[nodiscard]] nixl_b_params_t get_ucx_backend_common_options(); -nixl_status_t ucx_status_to_nixl(ucs_status_t status); +[[nodiscard]] nixl_status_t +ucx_status_to_nixl(ucs_status_t status); [[nodiscard]] std::string_view ucx_err_mode_to_string(ucp_err_handling_mode_t t); diff --git a/src/utils/meson.build b/src/utils/meson.build index e7ee0a508..3e2b7bf17 100644 --- a/src/utils/meson.build +++ b/src/utils/meson.build @@ -15,9 +15,6 @@ subdir('common') subdir('serdes') -if ucx_dep.found() - subdir('ucx') -endif subdir('stream') subdir('file') diff --git a/src/utils/ucx/meson.build b/src/utils/ucx/meson.build deleted file mode 100644 index cf7016944..000000000 --- a/src/utils/ucx/meson.build +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -ucx_utils_dep = [ ucx_dep, nixl_common_dep, serdes_interface ] - -if cuda_dep.found() - ucx_utils_dep += [ cuda_dep ] -endif - -ucx_utils_inc_dirs = include_directories('.') - -ucx_utils_sources = ['ucx_utils.cpp', 'ucx_utils.h', 'config.h', 'config.cpp', 'rkey.cpp', 'rkey.h', 'gpu_xfer_req_h.cpp', 'gpu_xfer_req_h.h'] - -ucx_utils_lib = library('ucx_utils', - ucx_utils_sources, - dependencies: ucx_utils_dep, - include_directories: [ nixl_inc_dirs, utils_inc_dirs ], - install: true) diff --git a/test/unit/utils/ucx/meson.build b/test/unit/utils/ucx/meson.build index b5354fa06..56f9ae227 100644 --- a/test/unit/utils/ucx/meson.build +++ b/test/unit/utils/ucx/meson.build @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ucx_utils_dep = [ ucx_dep, nixl_common_deps ] +ucx_utils_dep = [ ucx_dep, nixl_common_deps, ucx_backend_interface] if cuda_dep.found() ucx_utils_dep += [ cuda_dep ] endif @@ -24,14 +24,12 @@ if get_option('buildtype') != 'release' 'ucx_worker_test.cpp', dependencies: ucx_utils_dep, include_directories: [nixl_inc_dirs,utils_inc_dirs], - link_with: ucx_utils_lib, install: true) ucx_am_bin = executable('ucx_am_test', 'ucx_am_test.cpp', dependencies: ucx_utils_dep, include_directories: [nixl_inc_dirs,utils_inc_dirs], - link_with: ucx_utils_lib, install: true) if cuda_dep.found() @@ -40,7 +38,6 @@ if get_option('buildtype') != 'release' dependencies: ucx_utils_dep, include_directories: [nixl_inc_dirs,utils_inc_dirs], cpp_args : '-DUSE_VRAM', - link_with: ucx_utils_lib, install: true) ucx_am_cuda_bin = executable('ucx_am_test_cuda', @@ -48,7 +45,6 @@ if get_option('buildtype') != 'release' dependencies: ucx_utils_dep, include_directories: [nixl_inc_dirs,utils_inc_dirs], cpp_args : '-DUSE_VRAM', - link_with: ucx_utils_lib, install: true) endif diff --git a/test/unit/utils/ucx/ucx_am_test.cpp b/test/unit/utils/ucx/ucx_am_test.cpp index 33d62eac8..9f0c646bf 100644 --- a/test/unit/utils/ucx/ucx_am_test.cpp +++ b/test/unit/utils/ucx/ucx_am_test.cpp @@ -20,7 +20,7 @@ #include #include -#include "ucx/ucx_utils.h" +#include "ucx_utils.h" using namespace std; diff --git a/test/unit/utils/ucx/ucx_worker_test.cpp b/test/unit/utils/ucx/ucx_worker_test.cpp index bde8b6d18..c4caf30c4 100644 --- a/test/unit/utils/ucx/ucx_worker_test.cpp +++ b/test/unit/utils/ucx/ucx_worker_test.cpp @@ -20,8 +20,8 @@ #include #include -#include "ucx/ucx_utils.h" -#include "ucx/rkey.h" +#include "ucx_utils.h" +#include "rkey.h" //TODO: meson conditional build for CUDA //#define USE_VRAM