// Patch overview: unified git diff touching the UCX backend plugin, the UCX utility wrapper,
// and two UCX unit tests.
// NOTE(review): in this copy of the patch the line breaks are collapsed and template argument
// lists appear to be missing (e.g. `static_cast(this)`, `std::vector requests_`); this looks
// like an extraction artifact and the text presumably cannot be applied as-is -- confirm
// against the original commit.
//
// src/plugins/ucx/ucx_backend.cpp
//   * Removes class nixlUcxIntReq; nixlUcxBackendH::requests_ now stores nixlUcxReq values.
//   * Adds private nixlUcxBackendH::checkConnection(status): asserts that connections_ is not
//     empty, calls checkTxState() on the worker_id endpoint of every stored connection and
//     returns the first state that is not NIXL_SUCCESS, otherwise the status passed in.
//   * nixlUcxBackendH::append() now receives the operation status and takes over the job of
//     the removed static _retHelper(): the connection is always inserted into connections_;
//     NIXL_IN_PROG stores the request, NIXL_SUCCESS stores nothing, any other status calls
//     release() and is returned to the caller. Callers in sendXferRange(), postXfer() and
//     checkXfer() are switched from _retHelper() to append().
//   * Adds the getConnections() accessor and uses checkConnection(ret) on the error paths that
//     previously called the per-request checkConnection(worker_id).
//   * nixlUcxContext is constructed without the sizeof(nixlUcxIntReq) argument.
//   * Adds nixlUcxEngine::sendXferRangeBatch(): posts read/write operations for consecutive
//     descriptors in [start_idx, end_idx) whose remote endpoint for worker_id is `ep`, and
//     stops at the first descriptor that resolves to a different endpoint. result.size counts
//     the descriptors posted. Only the most recent in-progress request is kept in result.req;
//     an earlier one is released with ucp_request_free() when a newer request is pending, and
//     the kept one is released when a later operation fails.
//   * sendXferRange() rejects operations other than NIXL_WRITE/NIXL_READ before posting,
//     appends one request per endpoint batch, then calls flushEp() on every connection
//     recorded in the handle (previously only the connection of remote[0] was flushed).
//   NOTE(review): the removed loop returned NIXL_ERR_INVALID_PARAM when the local and remote
//   lengths differed; the new code only has NIXL_ASSERT(lsize == remote[i].len). If
//   NIXL_ASSERT can be compiled out, a length mismatch is no longer rejected -- TODO confirm.
//   NOTE(review): earlier in-progress requests of a batch are freed without being tracked by
//   the handle; presumably their completion is covered by the per-connection flush request --
//   confirm against the UCX request/flush semantics.
//   NOTE(review): the sendXferRange() loop advances with `i += result.size`; the first
//   descriptor of each batch always passes the endpoint comparison because `ep` is obtained
//   from remote[i] itself, so result.size is at least 1 for a non-empty range.
//
// src/plugins/ucx/ucx_backend.h
//   * Declares struct batchResult {status, size, req} and the static sendXferRangeBatch().
//
// src/utils/ucx/ucx_utils.cpp, src/utils/ucx/ucx_utils.h
//   * Drops the req_size parameter of the nixlUcxContext constructor together with the code
//     that set ucp_params.request_size and UCP_PARAM_FIELD_REQUEST_SIZE.
//
// test/unit/utils/ucx/ucx_am_test.cpp, test/unit/utils/ucx/ucx_worker_test.cpp
//   * Updates the nixlUcxContext initializers for the shorter constructor argument list and
//     joins the nixlUcxWorker array initializers onto a single line.
diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp index 9be55fb28..d2e402e84 100644 --- a/src/plugins/ucx/ucx_backend.cpp +++ b/src/plugins/ucx/ucx_backend.cpp @@ -265,32 +265,6 @@ void nixlUcxEngine::vramFiniCtx() cudaCtx.reset(); } -/**************************************** - * UCX request management -*****************************************/ - - -class nixlUcxIntReq { -public: - operator nixlUcxReq() noexcept { - return static_cast(this); - } - - void - setConnection(nixlUcxConnection *conn) { - conn_ = conn; - } - - nixl_status_t - checkConnection(size_t ep_id) const { - NIXL_ASSERT(conn_) << "Connection is not set"; - return conn_->getEp(ep_id)->checkTxState(); - } - -private: - nixlUcxConnection *conn_; -}; - /**************************************** * Backend request management *****************************************/ @@ -298,7 +272,7 @@ class nixlUcxIntReq { class nixlUcxBackendH : public nixlBackendReqH { private: std::set connections_; - std::vector requests_; + std::vector requests_; nixlUcxWorker *worker; size_t worker_id; @@ -313,26 +287,54 @@ class nixlUcxBackendH : public nixlBackendReqH { }; std::optional notif; -public: - auto& notification() { - return notif; + nixl_status_t + checkConnection(nixl_status_t status = NIXL_SUCCESS) const { + NIXL_ASSERT(!connections_.empty()); + for (const auto &conn : connections_) { + nixl_status_t conn_status = conn->getEp(worker_id)->checkTxState(); + if (conn_status != NIXL_SUCCESS) { + return conn_status; + } + } + return status; } +public: nixlUcxBackendH(nixlUcxWorker *worker, size_t worker_id) : worker(worker), worker_id(worker_id) {} + auto & + notification() { + return notif; + } + void reserve(size_t size) { requests_.reserve(size); } - void - append(nixlUcxReq req, ucx_connection_ptr_t conn) { - auto req_int = static_cast(req); - req_int->setConnection(conn.get()); - requests_.push_back(req_int); + nixl_status_t + append(nixl_status_t status, nixlUcxReq req, 
ucx_connection_ptr_t conn) { connections_.insert(conn); + switch (status) { + case NIXL_IN_PROG: + requests_.push_back(req); + break; + case NIXL_SUCCESS: + // Nothing to do + break; + default: + // Error. Release all previously initiated ops and exit: + release(); + return status; + } + return NIXL_SUCCESS; + } + + const std::set & + getConnections() const { + return connections_; } virtual bool @@ -343,7 +345,7 @@ class nixlUcxBackendH : public nixlBackendReqH { virtual nixl_status_t release() { // TODO: Error log: uncompleted requests found! Cancelling ... - for (nixlUcxIntReq *req : requests_) { + for (nixlUcxReq req : requests_) { nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); if (ret == NIXL_IN_PROG) { // TODO: Need process this properly. @@ -370,20 +372,19 @@ class nixlUcxBackendH : public nixlBackendReqH { /* If last request is incomplete, return NIXL_IN_PROG early without * checking other requests */ - nixlUcxIntReq *req = requests_.back(); + nixlUcxReq req = requests_.back(); nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); if (ret == NIXL_IN_PROG) { return NIXL_IN_PROG; } else if (ret != NIXL_SUCCESS) { - nixl_status_t conn_status = req->checkConnection(worker_id); - return (conn_status == NIXL_SUCCESS) ? ret : conn_status; + return checkConnection(ret); } /* Last request completed successfully, all the others must be in the * same state. TODO: remove extra checks? 
*/ size_t incomplete_reqs = 0; nixl_status_t out_ret = NIXL_SUCCESS; - for (nixlUcxIntReq *req : requests_) { + for (nixlUcxReq req : requests_) { nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); if (__builtin_expect(ret == NIXL_SUCCESS, 0)) { worker->reqRelease(req); @@ -394,8 +395,7 @@ class nixlUcxBackendH : public nixlBackendReqH { requests_[incomplete_reqs++] = req; } else { // Any other ret value is ERR and will be returned - nixl_status_t conn_status = req->checkConnection(worker_id); - out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status; + out_ret = checkConnection(ret); } } @@ -1102,7 +1102,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params) } uc = std::make_unique( - devs, sizeof(nixlUcxIntReq), init_params.enableProgTh, num_workers, init_params.syncMode); + devs, init_params.enableProgTh, num_workers, init_params.syncMode); for (size_t i = 0; i < num_workers; i++) { uws.emplace_back(std::make_unique(*uc, err_handling_mode)); @@ -1324,24 +1324,6 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) { * Data movement *****************************************/ -static nixl_status_t -_retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) { - /* if transfer wasn't immediately completed */ - switch(ret) { - case NIXL_IN_PROG: - hndl->append(req, conn); - case NIXL_SUCCESS: - // Nothing to do - break; - default: - // Error. 
Release all previously initiated ops and exit: - hndl->release(); - return ret; - } - - return NIXL_SUCCESS; -} - size_t nixlUcxEngine::getWorkerId() const { auto it = tlsSharedWorkerMap().find(this); @@ -1461,6 +1443,56 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation, return NIXL_SUCCESS; } +nixlUcxEngine::batchResult +nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep, + nixl_xfer_op_t operation, + const nixl_meta_dlist_t &local, + const nixl_meta_dlist_t &remote, + size_t worker_id, + size_t start_idx, + size_t end_idx) { + batchResult result = {NIXL_SUCCESS, 0, nullptr}; + + for (size_t i = start_idx; i < end_idx; ++i) { + void *laddr = (void *)local[i].addr; + size_t lsize = local[i].len; + uint64_t raddr = static_cast(remote[i].addr); + NIXL_ASSERT(lsize == remote[i].len); + + auto lmd = static_cast(local[i].metadataP); + auto rmd = static_cast(remote[i].metadataP); + auto &rmd_ep = rmd->conn->getEp(worker_id); + if (__builtin_expect(rmd_ep.get() != &ep, 0)) { + break; + } + + ++result.size; + nixlUcxReq req; + nixl_status_t ret = operation == NIXL_READ ? 
+ ep.read(raddr, rmd->getRkey(worker_id), laddr, lmd->mem, lsize, req) : + ep.write(laddr, lmd->mem, raddr, rmd->getRkey(worker_id), lsize, req); + + if (ret == NIXL_IN_PROG) { + if (__builtin_expect(result.req != nullptr, 1)) { + ucp_request_free(result.req); + } + result.req = req; + } else if (ret != NIXL_SUCCESS) { + result.status = ret; + if (result.req != nullptr) { + ucp_request_free(result.req); + result.req = nullptr; + } + break; + } + } + + if (result.status == NIXL_SUCCESS && result.req) { + result.status = NIXL_IN_PROG; + } + return result; +} + nixl_status_t nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, const nixl_meta_dlist_t &local, @@ -1470,54 +1502,44 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, size_t start_idx, size_t end_idx) const { nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle; - nixlUcxPrivateMetadata *lmd; - nixlUcxPublicMetadata *rmd; - nixl_status_t ret; - nixlUcxReq req; size_t workerId = intHandle->getWorkerId(); + nixl_status_t ret; - // Reserve space for the requests, +2 for flush and completion - intHandle->reserve(end_idx - start_idx + 2); + if (operation != NIXL_WRITE && operation != NIXL_READ) { + return NIXL_ERR_INVALID_PARAM; + } - for (size_t i = start_idx; i < end_idx; i++) { - void *laddr = (void*) local[i].addr; - size_t lsize = local[i].len; - uint64_t raddr = (uint64_t)remote[i].addr; - size_t rsize = remote[i].len; + /* Assuming we have a single EP, we need 3 requests: one pending request, + * one flush request, and one notification request */ + intHandle->reserve(3); - lmd = (nixlUcxPrivateMetadata*) local[i].metadataP; - rmd = (nixlUcxPublicMetadata*) remote[i].metadataP; + for (size_t i = start_idx; i < end_idx;) { + /* Send requests to a single EP */ + auto rmd = static_cast(remote[i].metadataP); auto &ep = rmd->conn->getEp(workerId); + auto result = sendXferRangeBatch(*ep, operation, local, remote, workerId, i, end_idx); - if (lsize != rsize) { - return 
NIXL_ERR_INVALID_PARAM; - } - - switch (operation) { - case NIXL_READ: - ret = ep->read(raddr, rmd->getRkey(workerId), laddr, lmd->mem, lsize, req); - break; - case NIXL_WRITE: - ret = ep->write(laddr, lmd->mem, raddr, rmd->getRkey(workerId), lsize, req); - break; - default: - return NIXL_ERR_INVALID_PARAM; - } - - if (_retHelper(ret, intHandle, req, rmd->conn)) { + /* Append a single pending request for the entire EP batch */ + ret = intHandle->append(result.status, result.req, rmd->conn); + if (ret != NIXL_SUCCESS) { return ret; } + + i += result.size; } /* * Flush keeps intHandle non-empty until the operation is actually * completed, which can happen after local requests completion. + * We need to flush all distinct connections to ensure that the operation + * is actually completed. */ - rmd = (nixlUcxPublicMetadata *)remote[0].metadataP; - ret = rmd->conn->getEp(workerId)->flushEp(req); - - if (_retHelper(ret, intHandle, req, rmd->conn)) { - return ret; + for (auto &conn : intHandle->getConnections()) { + nixlUcxReq req; + ret = conn->getEp(workerId)->flushEp(req); + if (intHandle->append(ret, req, conn) != NIXL_SUCCESS) { + return ret; + } } return NIXL_SUCCESS; @@ -1557,7 +1579,7 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation, opt_args->notifMsg, rmd->conn->getEp(int_handle->getWorkerId()), &req); - if (_retHelper(ret, int_handle, req, rmd->conn)) { + if (int_handle->append(ret, req, rmd->conn) != NIXL_SUCCESS) { return ret; } @@ -1594,8 +1616,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const nixl_status_t status = notifSendPriv(notif->agent, notif->payload, conn->getEp(intHandle->getWorkerId()), &req); notif.reset(); - status = _retHelper(status, intHandle, req, conn); - if (status != NIXL_SUCCESS) { + + if (intHandle->append(status, req, conn) != NIXL_SUCCESS) { return status; } diff --git a/src/plugins/ucx/ucx_backend.h b/src/plugins/ucx/ucx_backend.h index bf9e6546d..90b8766b1 100644 --- 
a/src/plugins/ucx/ucx_backend.h +++ b/src/plugins/ucx/ucx_backend.h @@ -294,6 +294,21 @@ class nixlUcxEngine : public nixlBackendEngine { ucx_connection_ptr_t getConnection(const std::string &remote_agent) const; + struct batchResult { + nixl_status_t status; + size_t size; + nixlUcxReq req; + }; + + static batchResult + sendXferRangeBatch(nixlUcxEp &ep, + nixl_xfer_op_t operation, + const nixl_meta_dlist_t &local, + const nixl_meta_dlist_t &remote, + size_t worker_id, + size_t start_idx, + size_t end_idx); + /* UCX data */ std::unique_ptr uc; std::vector> uws; diff --git a/src/utils/ucx/ucx_utils.cpp b/src/utils/ucx/ucx_utils.cpp index d337e7247..75d634243 100644 --- a/src/utils/ucx/ucx_utils.cpp +++ b/src/utils/ucx/ucx_utils.cpp @@ -406,7 +406,6 @@ bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept } nixlUcxContext::nixlUcxContext(std::vector devs, - size_t req_size, bool prog_thread, unsigned long num_workers, nixl_thread_sync_t sync_mode) { @@ -429,11 +428,6 @@ nixlUcxContext::nixlUcxContext(std::vector devs, ucp_params.features |= UCP_FEATURE_WAKEUP; ucp_params.mt_workers_shared = num_workers > 1 ? 
1 : 0; - if (req_size) { - ucp_params.request_size = req_size; - ucp_params.field_mask |= UCP_PARAM_FIELD_REQUEST_SIZE; - } - nixl::ucx::config config; /* If requested, restrict the set of network devices */ diff --git a/src/utils/ucx/ucx_utils.h b/src/utils/ucx/ucx_utils.h index 44c06ca44..477a70e7b 100644 --- a/src/utils/ucx/ucx_utils.h +++ b/src/utils/ucx/ucx_utils.h @@ -199,7 +199,6 @@ class nixlUcxContext { public: nixlUcxContext(std::vector devices, - size_t req_size, bool prog_thread, unsigned long num_workers, nixl_thread_sync_t sync_mode); diff --git a/test/unit/utils/ucx/ucx_am_test.cpp b/test/unit/utils/ucx/ucx_am_test.cpp index 33d62eac8..9646b8f6d 100644 --- a/test/unit/utils/ucx/ucx_am_test.cpp +++ b/test/unit/utils/ucx/ucx_am_test.cpp @@ -53,13 +53,10 @@ int main() vector devs; devs.push_back("mlx5_0"); - nixlUcxContext c[2] = {{devs, 0, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}, - {devs, 0, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}}; + nixlUcxContext c[2] = {{devs, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}, + {devs, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}}; - nixlUcxWorker w[2] = { - nixlUcxWorker(c[0]), - nixlUcxWorker(c[1]) - }; + nixlUcxWorker w[2] = {nixlUcxWorker(c[0]), nixlUcxWorker(c[1])}; std::unique_ptr ep[2]; nixlUcxReq req; uint64_t buffer; diff --git a/test/unit/utils/ucx/ucx_worker_test.cpp b/test/unit/utils/ucx/ucx_worker_test.cpp index bde8b6d18..9cd2e04e3 100644 --- a/test/unit/utils/ucx/ucx_worker_test.cpp +++ b/test/unit/utils/ucx/ucx_worker_test.cpp @@ -73,13 +73,10 @@ int main() // TODO: pass dev name for testing // in CI it would be goot to test both SHM and IB //devs.push_back("mlx5_0"); - nixlUcxContext c[2] = {{devs, 0, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}, - {devs, 0, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}}; + nixlUcxContext c[2] = {{devs, false, 1, nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}, + {devs, false, 1, 
nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE}}; - nixlUcxWorker w[2] = {
- nixlUcxWorker(c[0]), - nixlUcxWorker(c[1]) - }; + nixlUcxWorker w[2] = {nixlUcxWorker(c[0]), nixlUcxWorker(c[1])}; std::unique_ptr ep[2]; nixlUcxMem mem[2]; std::unique_ptr rkey[2];