218 changes: 120 additions & 98 deletions src/plugins/ucx/ucx_backend.cpp
@@ -265,40 +265,14 @@ void nixlUcxEngine::vramFiniCtx()
cudaCtx.reset();
}

/****************************************
* UCX request management
*****************************************/


class nixlUcxIntReq {
public:
operator nixlUcxReq() noexcept {
return static_cast<nixlUcxReq>(this);
}

void
setConnection(nixlUcxConnection *conn) {
conn_ = conn;
}

nixl_status_t
checkConnection(size_t ep_id) const {
NIXL_ASSERT(conn_) << "Connection is not set";
return conn_->getEp(ep_id)->checkTxState();
}

private:
nixlUcxConnection *conn_;
};

/****************************************
* Backend request management
*****************************************/

class nixlUcxBackendH : public nixlBackendReqH {
private:
std::set<ucx_connection_ptr_t> connections_;
std::vector<nixlUcxIntReq *> requests_;
std::vector<nixlUcxReq> requests_;
nixlUcxWorker *worker;
size_t worker_id;

@@ -313,26 +287,54 @@ class nixlUcxBackendH : public nixlBackendReqH {
};
std::optional<Notif> notif;

public:
auto& notification() {
return notif;
nixl_status_t
checkConnection(nixl_status_t status = NIXL_SUCCESS) const {
NIXL_ASSERT(!connections_.empty());
for (const auto &conn : connections_) {
nixl_status_t conn_status = conn->getEp(worker_id)->checkTxState();
if (conn_status != NIXL_SUCCESS) {
return conn_status;
}
}
return status;
}

public:
nixlUcxBackendH(nixlUcxWorker *worker, size_t worker_id)
: worker(worker),
worker_id(worker_id) {}

auto &
notification() {
return notif;
}

void
reserve(size_t size) {
requests_.reserve(size);
}

void
append(nixlUcxReq req, ucx_connection_ptr_t conn) {
auto req_int = static_cast<nixlUcxIntReq *>(req);
req_int->setConnection(conn.get());
requests_.push_back(req_int);
nixl_status_t
append(nixl_status_t status, nixlUcxReq req, ucx_connection_ptr_t conn) {
connections_.insert(conn);
switch (status) {
case NIXL_IN_PROG:
requests_.push_back(req);
break;
case NIXL_SUCCESS:
// Nothing to do
break;
default:
// Error. Release all previously initiated ops and exit:
release();
return status;
}
return NIXL_SUCCESS;
}

const std::set<ucx_connection_ptr_t> &
getConnections() const {
return connections_;
}

virtual bool
@@ -343,7 +345,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
virtual nixl_status_t
release() {
// TODO: Error log: uncompleted requests found! Cancelling ...
for (nixlUcxIntReq *req : requests_) {
for (nixlUcxReq req : requests_) {
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
if (ret == NIXL_IN_PROG) {
// TODO: Need process this properly.
@@ -370,20 +372,19 @@ class nixlUcxBackendH : public nixlBackendReqH {

/* If last request is incomplete, return NIXL_IN_PROG early without
* checking other requests */
nixlUcxIntReq *req = requests_.back();
nixlUcxReq req = requests_.back();
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
if (ret == NIXL_IN_PROG) {
return NIXL_IN_PROG;
} else if (ret != NIXL_SUCCESS) {
nixl_status_t conn_status = req->checkConnection(worker_id);
return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
return checkConnection(ret);
}

/* Last request completed successfully, all the others must be in the
* same state. TODO: remove extra checks? */
size_t incomplete_reqs = 0;
nixl_status_t out_ret = NIXL_SUCCESS;
for (nixlUcxIntReq *req : requests_) {
for (nixlUcxReq req : requests_) {
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
if (__builtin_expect(ret == NIXL_SUCCESS, 0)) {
worker->reqRelease(req);
@@ -394,8 +395,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
requests_[incomplete_reqs++] = req;
} else {
// Any other ret value is ERR and will be returned
nixl_status_t conn_status = req->checkConnection(worker_id);
out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status;
out_ret = checkConnection(ret);
}
}

@@ -1102,7 +1102,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
}

uc = std::make_unique<nixlUcxContext>(
devs, sizeof(nixlUcxIntReq), init_params.enableProgTh, num_workers, init_params.syncMode);
devs, init_params.enableProgTh, num_workers, init_params.syncMode);

for (size_t i = 0; i < num_workers; i++) {
uws.emplace_back(std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));
@@ -1324,24 +1324,6 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
* Data movement
*****************************************/

static nixl_status_t
_retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) {
/* if transfer wasn't immediately completed */
switch(ret) {
case NIXL_IN_PROG:
hndl->append(req, conn);
case NIXL_SUCCESS:
// Nothing to do
break;
default:
// Error. Release all previously initiated ops and exit:
hndl->release();
return ret;
}

return NIXL_SUCCESS;
}

size_t
nixlUcxEngine::getWorkerId() const {
auto it = tlsSharedWorkerMap().find(this);
@@ -1461,6 +1443,56 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation,
return NIXL_SUCCESS;
}

nixlUcxEngine::batchResult
nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep,
nixl_xfer_op_t operation,
const nixl_meta_dlist_t &local,
const nixl_meta_dlist_t &remote,
size_t worker_id,
size_t start_idx,
size_t end_idx) {
batchResult result = {NIXL_SUCCESS, 0, nullptr};

for (size_t i = start_idx; i < end_idx; ++i) {
void *laddr = (void *)local[i].addr;
size_t lsize = local[i].len;
uint64_t raddr = (uint64_t)remote[i].addr;
NIXL_ASSERT(lsize == remote[i].len);

auto lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP);
auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
auto &rmd_ep = rmd->conn->getEp(worker_id);
if (__builtin_expect(rmd_ep.get() != &ep, 0)) {
break;
}

++result.size;
nixlUcxReq req;
nixl_status_t ret = operation == NIXL_READ ?
ep.read(raddr, rmd->getRkey(worker_id), laddr, lmd->mem, lsize, req) :
ep.write(laddr, lmd->mem, raddr, rmd->getRkey(worker_id), lsize, req);

if (ret == NIXL_IN_PROG) {
if (__builtin_expect(result.req != nullptr, 1)) {
ucp_request_free(result.req);
}
result.req = req;
Review comment (Contributor):
How do we use this request? It can be returned to the memory pool by UCX at any moment after the free.

Reply (Contributor, PR author):
As you can see, we don't use the freed request at all.
Instead, the idea is to keep the LAST pending (incomplete) request: when we detect that the current request is pending, we free the previously stored pending request (because we now have a more recent one) and remember the recent one.

Later on, we use this last pending request in the "waiting for completion" stage (checkXfer/status) in order to:

  • detect whether the request completed
  • handle errors
    In both cases, the request is returned back to UCX (either in status() -> worker->reqRelease() or in release() -> worker->reqRelease()).
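A minimal standalone sketch of this keep-only-the-last-pending-request idea (assumption: fake_post/fake_free/fake_poll below are hypothetical stand-ins, not the NIXL or UCX API; the real code goes through ucp_request_check_status()/ucp_request_free() and worker->reqRelease()):

#include <cstdio>

enum class Status { Success, InProgress, Error };

struct Request { int id; };

// Pretend to post one RDMA op; odd iterations complete immediately.
static Status fake_post(int i, Request **req) {
    if (i % 2 == 1) { *req = nullptr; return Status::Success; }
    *req = new Request{i};
    return Status::InProgress;
}

static void fake_free(Request *req) { delete req; }  // stand-in for ucp_request_free()
static bool fake_poll(Request *) { return true; }    // stand-in for ucp_request_check_status()

int main() {
    Request *last_pending = nullptr;  // the only handle the batch keeps

    for (int i = 0; i < 4; ++i) {
        Request *req = nullptr;
        Status st = fake_post(i, &req);
        if (st == Status::InProgress) {
            if (last_pending != nullptr)
                fake_free(last_pending);  // a newer pending op supersedes the older handle
            last_pending = req;
        } else if (st != Status::Success) {
            if (last_pending != nullptr) { fake_free(last_pending); last_pending = nullptr; }
            return 1;  // error: release the kept handle and bail out
        }
    }

    // "checkXfer/status" stage: poll only the retained handle, then release it.
    if (last_pending != nullptr && fake_poll(last_pending)) {
        fake_free(last_pending);
        last_pending = nullptr;
    }
    std::puts("batch complete");
    return 0;
}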

} else if (ret != NIXL_SUCCESS) {
result.status = ret;
if (result.req != nullptr) {
ucp_request_free(result.req);
result.req = nullptr;
}
break;
}
}

if (result.status == NIXL_SUCCESS && result.req) {
result.status = NIXL_IN_PROG;
}
return result;
}

nixl_status_t
nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
const nixl_meta_dlist_t &local,
@@ -1470,54 +1502,44 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
size_t start_idx,
size_t end_idx) const {
nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
nixlUcxPrivateMetadata *lmd;
nixlUcxPublicMetadata *rmd;
nixl_status_t ret;
nixlUcxReq req;
size_t workerId = intHandle->getWorkerId();
nixl_status_t ret;

// Reserve space for the requests, +2 for flush and completion
intHandle->reserve(end_idx - start_idx + 2);
if (operation != NIXL_WRITE && operation != NIXL_READ) {
return NIXL_ERR_INVALID_PARAM;
}

for (size_t i = start_idx; i < end_idx; i++) {
void *laddr = (void*) local[i].addr;
size_t lsize = local[i].len;
uint64_t raddr = (uint64_t)remote[i].addr;
size_t rsize = remote[i].len;
/* Assuming we have a single EP, we need 3 requests: one pending request,
* one flush request, and one notification request */
intHandle->reserve(3);

lmd = (nixlUcxPrivateMetadata*) local[i].metadataP;
rmd = (nixlUcxPublicMetadata*) remote[i].metadataP;
for (size_t i = start_idx; i < end_idx;) {
/* Send requests to a single EP */
auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
auto &ep = rmd->conn->getEp(workerId);
auto result = sendXferRangeBatch(*ep, operation, local, remote, workerId, i, end_idx);

if (lsize != rsize) {
return NIXL_ERR_INVALID_PARAM;
}

switch (operation) {
case NIXL_READ:
ret = ep->read(raddr, rmd->getRkey(workerId), laddr, lmd->mem, lsize, req);
break;
case NIXL_WRITE:
ret = ep->write(laddr, lmd->mem, raddr, rmd->getRkey(workerId), lsize, req);
break;
default:
return NIXL_ERR_INVALID_PARAM;
}

if (_retHelper(ret, intHandle, req, rmd->conn)) {
/* Append a single pending request for the entire EP batch */
ret = intHandle->append(result.status, result.req, rmd->conn);
if (ret != NIXL_SUCCESS) {
return ret;
}

i += result.size;
}

/*
* Flush keeps intHandle non-empty until the operation is actually
* completed, which can happen after local requests completion.
* We need to flush all distinct connections to ensure that the operation
* is actually completed.
*/
rmd = (nixlUcxPublicMetadata *)remote[0].metadataP;
ret = rmd->conn->getEp(workerId)->flushEp(req);

if (_retHelper(ret, intHandle, req, rmd->conn)) {
return ret;
for (auto &conn : intHandle->getConnections()) {
nixlUcxReq req;
ret = conn->getEp(workerId)->flushEp(req);
if (intHandle->append(ret, req, conn) != NIXL_SUCCESS) {
return ret;
}
}

return NIXL_SUCCESS;
@@ -1557,7 +1579,7 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
opt_args->notifMsg,
rmd->conn->getEp(int_handle->getWorkerId()),
&req);
if (_retHelper(ret, int_handle, req, rmd->conn)) {
if (int_handle->append(ret, req, rmd->conn) != NIXL_SUCCESS) {
return ret;
}

@@ -1594,8 +1616,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
nixl_status_t status =
notifSendPriv(notif->agent, notif->payload, conn->getEp(intHandle->getWorkerId()), &req);
notif.reset();
status = _retHelper(status, intHandle, req, conn);
if (status != NIXL_SUCCESS) {

if (intHandle->append(status, req, conn) != NIXL_SUCCESS) {
return status;
}

15 changes: 15 additions & 0 deletions src/plugins/ucx/ucx_backend.h
@@ -294,6 +294,21 @@ class nixlUcxEngine : public nixlBackendEngine {
ucx_connection_ptr_t
getConnection(const std::string &remote_agent) const;

struct batchResult {
nixl_status_t status;
size_t size;
nixlUcxReq req;
};

static batchResult
sendXferRangeBatch(nixlUcxEp &ep,
nixl_xfer_op_t operation,
const nixl_meta_dlist_t &local,
const nixl_meta_dlist_t &remote,
size_t worker_id,
size_t start_idx,
size_t end_idx);

/* UCX data */
std::unique_ptr<nixlUcxContext> uc;
std::vector<std::unique_ptr<nixlUcxWorker>> uws;
6 changes: 0 additions & 6 deletions src/utils/ucx/ucx_utils.cpp
@@ -406,7 +406,6 @@ bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept
}

nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
size_t req_size,
bool prog_thread,
unsigned long num_workers,
nixl_thread_sync_t sync_mode) {
@@ -429,11 +428,6 @@ nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
ucp_params.features |= UCP_FEATURE_WAKEUP;
ucp_params.mt_workers_shared = num_workers > 1 ? 1 : 0;

if (req_size) {
ucp_params.request_size = req_size;
ucp_params.field_mask |= UCP_PARAM_FIELD_REQUEST_SIZE;
}

nixl::ucx::config config;

/* If requested, restrict the set of network devices */
1 change: 0 additions & 1 deletion src/utils/ucx/ucx_utils.h
@@ -199,7 +199,6 @@ class nixlUcxContext {

public:
nixlUcxContext(std::vector<std::string> devices,
size_t req_size,
bool prog_thread,
unsigned long num_workers,
nixl_thread_sync_t sync_mode);