-
Notifications
You must be signed in to change notification settings - Fork 183
Advanced request handling optimizations #1009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
9519c2d
abdb803
883ca96
363f98e
a5d1207
0d23918
cd261e3
b6c6bda
203f199
9272ef5
2edc9df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -265,40 +265,14 @@ void nixlUcxEngine::vramFiniCtx() | |
| cudaCtx.reset(); | ||
| } | ||
|
|
||
| /**************************************** | ||
| * UCX request management | ||
| *****************************************/ | ||
|
|
||
|
|
||
| class nixlUcxIntReq { | ||
| public: | ||
| operator nixlUcxReq() noexcept { | ||
| return static_cast<nixlUcxReq>(this); | ||
| } | ||
|
|
||
| void | ||
| setConnection(nixlUcxConnection *conn) { | ||
| conn_ = conn; | ||
| } | ||
|
|
||
| nixl_status_t | ||
| checkConnection(size_t ep_id) const { | ||
| NIXL_ASSERT(conn_) << "Connection is not set"; | ||
| return conn_->getEp(ep_id)->checkTxState(); | ||
| } | ||
|
|
||
| private: | ||
| nixlUcxConnection *conn_; | ||
| }; | ||
|
|
||
| /**************************************** | ||
| * Backend request management | ||
| *****************************************/ | ||
|
|
||
| class nixlUcxBackendH : public nixlBackendReqH { | ||
| private: | ||
| std::set<ucx_connection_ptr_t> connections_; | ||
| std::vector<nixlUcxIntReq *> requests_; | ||
| std::vector<nixlUcxReq> requests_; | ||
| nixlUcxWorker *worker; | ||
| size_t worker_id; | ||
|
|
||
|
|
@@ -313,26 +287,54 @@ class nixlUcxBackendH : public nixlBackendReqH { | |
| }; | ||
| std::optional<Notif> notif; | ||
|
|
||
| public: | ||
| auto& notification() { | ||
| return notif; | ||
| nixl_status_t | ||
| checkConnection(nixl_status_t status = NIXL_SUCCESS) const { | ||
| NIXL_ASSERT(!connections_.empty()); | ||
| for (const auto &conn : connections_) { | ||
| nixl_status_t conn_status = conn->getEp(worker_id)->checkTxState(); | ||
| if (conn_status != NIXL_SUCCESS) { | ||
| return conn_status; | ||
| } | ||
| } | ||
| return status; | ||
| } | ||
|
|
||
| public: | ||
| nixlUcxBackendH(nixlUcxWorker *worker, size_t worker_id) | ||
| : worker(worker), | ||
| worker_id(worker_id) {} | ||
|
|
||
| auto & | ||
| notification() { | ||
| return notif; | ||
| } | ||
|
|
||
| void | ||
| reserve(size_t size) { | ||
| requests_.reserve(size); | ||
| } | ||
|
|
||
| void | ||
| append(nixlUcxReq req, ucx_connection_ptr_t conn) { | ||
| auto req_int = static_cast<nixlUcxIntReq *>(req); | ||
| req_int->setConnection(conn.get()); | ||
| requests_.push_back(req_int); | ||
| nixl_status_t | ||
| append(nixl_status_t status, nixlUcxReq req, ucx_connection_ptr_t conn) { | ||
| connections_.insert(conn); | ||
| switch (status) { | ||
| case NIXL_IN_PROG: | ||
| requests_.push_back(req); | ||
| break; | ||
| case NIXL_SUCCESS: | ||
| // Nothing to do | ||
| break; | ||
| default: | ||
| // Error. Release all previously initiated ops and exit: | ||
| release(); | ||
| return status; | ||
| } | ||
| return NIXL_SUCCESS; | ||
| } | ||
|
|
||
| const std::set<ucx_connection_ptr_t> & | ||
| getConnections() const { | ||
| return connections_; | ||
| } | ||
|
|
||
| virtual bool | ||
|
|
@@ -343,7 +345,7 @@ class nixlUcxBackendH : public nixlBackendReqH { | |
| virtual nixl_status_t | ||
| release() { | ||
| // TODO: Error log: uncompleted requests found! Cancelling ... | ||
| for (nixlUcxIntReq *req : requests_) { | ||
| for (nixlUcxReq req : requests_) { | ||
| nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); | ||
| if (ret == NIXL_IN_PROG) { | ||
| // TODO: Need process this properly. | ||
|
|
@@ -370,20 +372,19 @@ class nixlUcxBackendH : public nixlBackendReqH { | |
|
|
||
| /* If last request is incomplete, return NIXL_IN_PROG early without | ||
| * checking other requests */ | ||
| nixlUcxIntReq *req = requests_.back(); | ||
| nixlUcxReq req = requests_.back(); | ||
| nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); | ||
| if (ret == NIXL_IN_PROG) { | ||
| return NIXL_IN_PROG; | ||
| } else if (ret != NIXL_SUCCESS) { | ||
| nixl_status_t conn_status = req->checkConnection(worker_id); | ||
| return (conn_status == NIXL_SUCCESS) ? ret : conn_status; | ||
| return checkConnection(ret); | ||
| } | ||
|
|
||
| /* Last request completed successfully, all the others must be in the | ||
| * same state. TODO: remove extra checks? */ | ||
| size_t incomplete_reqs = 0; | ||
| nixl_status_t out_ret = NIXL_SUCCESS; | ||
| for (nixlUcxIntReq *req : requests_) { | ||
| for (nixlUcxReq req : requests_) { | ||
| nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req)); | ||
| if (__builtin_expect(ret == NIXL_SUCCESS, 0)) { | ||
| worker->reqRelease(req); | ||
|
|
@@ -394,8 +395,7 @@ class nixlUcxBackendH : public nixlBackendReqH { | |
| requests_[incomplete_reqs++] = req; | ||
| } else { | ||
| // Any other ret value is ERR and will be returned | ||
| nixl_status_t conn_status = req->checkConnection(worker_id); | ||
| out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status; | ||
| out_ret = checkConnection(ret); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1102,7 +1102,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params) | |
| } | ||
|
|
||
| uc = std::make_unique<nixlUcxContext>( | ||
| devs, sizeof(nixlUcxIntReq), init_params.enableProgTh, num_workers, init_params.syncMode); | ||
| devs, init_params.enableProgTh, num_workers, init_params.syncMode); | ||
|
|
||
| for (size_t i = 0; i < num_workers; i++) { | ||
| uws.emplace_back(std::make_unique<nixlUcxWorker>(*uc, err_handling_mode)); | ||
|
|
@@ -1324,24 +1324,6 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) { | |
| * Data movement | ||
| *****************************************/ | ||
|
|
||
| static nixl_status_t | ||
| _retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) { | ||
| /* if transfer wasn't immediately completed */ | ||
| switch(ret) { | ||
| case NIXL_IN_PROG: | ||
| hndl->append(req, conn); | ||
| case NIXL_SUCCESS: | ||
| // Nothing to do | ||
| break; | ||
| default: | ||
| // Error. Release all previously initiated ops and exit: | ||
| hndl->release(); | ||
| return ret; | ||
| } | ||
|
|
||
| return NIXL_SUCCESS; | ||
| } | ||
|
|
||
| size_t | ||
| nixlUcxEngine::getWorkerId() const { | ||
| auto it = tlsSharedWorkerMap().find(this); | ||
|
|
@@ -1461,6 +1443,56 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation, | |
| return NIXL_SUCCESS; | ||
| } | ||
|
|
||
| nixlUcxEngine::batchResult | ||
| nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep, | ||
| nixl_xfer_op_t operation, | ||
| const nixl_meta_dlist_t &local, | ||
| const nixl_meta_dlist_t &remote, | ||
| size_t worker_id, | ||
| size_t start_idx, | ||
| size_t end_idx) { | ||
| batchResult result = {NIXL_SUCCESS, 0, nullptr}; | ||
|
|
||
| for (size_t i = start_idx; i < end_idx; ++i) { | ||
| void *laddr = (void *)local[i].addr; | ||
| size_t lsize = local[i].len; | ||
| uint64_t raddr = (uint64_t)remote[i].addr; | ||
rakhmets marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| NIXL_ASSERT(lsize == remote[i].len); | ||
|
|
||
| auto lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP); | ||
| auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP); | ||
| auto &rmd_ep = rmd->conn->getEp(worker_id); | ||
| if (__builtin_expect(rmd_ep.get() != &ep, 0)) { | ||
| break; | ||
| } | ||
|
|
||
| ++result.size; | ||
| nixlUcxReq req; | ||
| nixl_status_t ret = operation == NIXL_READ ? | ||
| ep.read(raddr, rmd->getRkey(worker_id), laddr, lmd->mem, lsize, req) : | ||
| ep.write(laddr, lmd->mem, raddr, rmd->getRkey(worker_id), lsize, req); | ||
|
|
||
| if (ret == NIXL_IN_PROG) { | ||
| if (__builtin_expect(result.req != nullptr, 1)) { | ||
| ucp_request_free(result.req); | ||
| } | ||
| result.req = req; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how do we use this request? It can be returned to memory pool by UCX at any moment after the free
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As you can see we don't use freed request at all. Later on we use this last pending request in "waiting for completion" stage (checkXfer/status) in order to:
|
||
| } else if (ret != NIXL_SUCCESS) { | ||
| result.status = ret; | ||
| if (result.req != nullptr) { | ||
| ucp_request_free(result.req); | ||
| result.req = nullptr; | ||
| } | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (result.status == NIXL_SUCCESS && result.req) { | ||
| result.status = NIXL_IN_PROG; | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
| nixl_status_t | ||
| nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, | ||
| const nixl_meta_dlist_t &local, | ||
|
|
@@ -1470,54 +1502,44 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, | |
| size_t start_idx, | ||
| size_t end_idx) const { | ||
| nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle; | ||
| nixlUcxPrivateMetadata *lmd; | ||
| nixlUcxPublicMetadata *rmd; | ||
| nixl_status_t ret; | ||
| nixlUcxReq req; | ||
| size_t workerId = intHandle->getWorkerId(); | ||
| nixl_status_t ret; | ||
|
|
||
| // Reserve space for the requests, +2 for flush and completion | ||
| intHandle->reserve(end_idx - start_idx + 2); | ||
| if (operation != NIXL_WRITE && operation != NIXL_READ) { | ||
| return NIXL_ERR_INVALID_PARAM; | ||
| } | ||
|
|
||
| for (size_t i = start_idx; i < end_idx; i++) { | ||
| void *laddr = (void*) local[i].addr; | ||
| size_t lsize = local[i].len; | ||
| uint64_t raddr = (uint64_t)remote[i].addr; | ||
| size_t rsize = remote[i].len; | ||
| /* Assuming we have a single EP, we need 3 requests: one pending request, | ||
| * one flush request, and one notification request */ | ||
| intHandle->reserve(3); | ||
|
|
||
| lmd = (nixlUcxPrivateMetadata*) local[i].metadataP; | ||
| rmd = (nixlUcxPublicMetadata*) remote[i].metadataP; | ||
| for (size_t i = start_idx; i < end_idx;) { | ||
| /* Send requests to a single EP */ | ||
| auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP); | ||
| auto &ep = rmd->conn->getEp(workerId); | ||
| auto result = sendXferRangeBatch(*ep, operation, local, remote, workerId, i, end_idx); | ||
rakhmets marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (lsize != rsize) { | ||
| return NIXL_ERR_INVALID_PARAM; | ||
| } | ||
|
|
||
| switch (operation) { | ||
| case NIXL_READ: | ||
| ret = ep->read(raddr, rmd->getRkey(workerId), laddr, lmd->mem, lsize, req); | ||
| break; | ||
| case NIXL_WRITE: | ||
| ret = ep->write(laddr, lmd->mem, raddr, rmd->getRkey(workerId), lsize, req); | ||
| break; | ||
| default: | ||
| return NIXL_ERR_INVALID_PARAM; | ||
| } | ||
|
|
||
| if (_retHelper(ret, intHandle, req, rmd->conn)) { | ||
| /* Append a single pending request for the entire EP batch */ | ||
| ret = intHandle->append(result.status, result.req, rmd->conn); | ||
| if (ret != NIXL_SUCCESS) { | ||
| return ret; | ||
| } | ||
|
|
||
| i += result.size; | ||
| } | ||
|
|
||
| /* | ||
| * Flush keeps intHandle non-empty until the operation is actually | ||
| * completed, which can happen after local requests completion. | ||
| * We need to flush all distinct connections to ensure that the operation | ||
| * is actually completed. | ||
| */ | ||
| rmd = (nixlUcxPublicMetadata *)remote[0].metadataP; | ||
| ret = rmd->conn->getEp(workerId)->flushEp(req); | ||
|
|
||
| if (_retHelper(ret, intHandle, req, rmd->conn)) { | ||
| return ret; | ||
| for (auto &conn : intHandle->getConnections()) { | ||
| nixlUcxReq req; | ||
| ret = conn->getEp(workerId)->flushEp(req); | ||
| if (intHandle->append(ret, req, conn) != NIXL_SUCCESS) { | ||
| return ret; | ||
| } | ||
| } | ||
|
|
||
| return NIXL_SUCCESS; | ||
|
|
@@ -1557,7 +1579,7 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation, | |
| opt_args->notifMsg, | ||
| rmd->conn->getEp(int_handle->getWorkerId()), | ||
| &req); | ||
| if (_retHelper(ret, int_handle, req, rmd->conn)) { | ||
| if (int_handle->append(ret, req, rmd->conn) != NIXL_SUCCESS) { | ||
| return ret; | ||
| } | ||
|
|
||
|
|
@@ -1594,8 +1616,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const | |
| nixl_status_t status = | ||
| notifSendPriv(notif->agent, notif->payload, conn->getEp(intHandle->getWorkerId()), &req); | ||
| notif.reset(); | ||
| status = _retHelper(status, intHandle, req, conn); | ||
| if (status != NIXL_SUCCESS) { | ||
|
|
||
| if (intHandle->append(status, req, conn) != NIXL_SUCCESS) { | ||
| return status; | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.