Skip to content

Commit 04e1216

Browse files
authored
NIXL/UCX: Optimize request handling (#982)
1 parent 8f8e98e commit 04e1216

File tree

6 files changed

+157
-194
lines changed

6 files changed

+157
-194
lines changed

src/plugins/ucx/ucx_backend.cpp

Lines changed: 79 additions & 113 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
#include <optional>
2525
#include <limits>
2626
#include <future>
27+
#include <set>
2728
#include <string.h>
2829
#include <unistd.h>
2930
#include "absl/strings/numbers.h"
@@ -271,20 +272,12 @@ void nixlUcxEngine::vramFiniCtx()
271272

272273
class nixlUcxIntReq {
273274
public:
274-
std::unique_ptr<std::string> amBuffer;
275-
276-
bool
277-
is_complete() const {
278-
return completed_;
275+
operator nixlUcxReq() noexcept {
276+
return static_cast<nixlUcxReq>(this);
279277
}
280278

281279
void
282-
completed() {
283-
completed_ = true;
284-
}
285-
286-
void
287-
setConnection(ucx_connection_ptr_t conn) {
280+
setConnection(nixlUcxConnection *conn) {
288281
conn_ = conn;
289282
}
290283

@@ -295,51 +288,28 @@ class nixlUcxIntReq {
295288
}
296289

297290
private:
298-
bool completed_ = false;
299-
ucx_connection_ptr_t conn_;
291+
nixlUcxConnection *conn_;
300292
};
301293

302-
static void
303-
nixlUcxReqSetConnection(nixlUcxReq req, ucx_connection_ptr_t conn) {
304-
nixlUcxIntReq *req_int = reinterpret_cast<nixlUcxIntReq *>(req);
305-
req_int->setConnection(conn);
306-
}
307-
308-
static void _internalRequestInit(void *request)
309-
{
310-
/* Initialize request in-place (aka "placement new")*/
311-
new(request) nixlUcxIntReq;
312-
}
313-
314-
static void _internalRequestFini(void *request)
315-
{
316-
/* Finalize request */
317-
nixlUcxIntReq *req = (nixlUcxIntReq*)request;
318-
req->~nixlUcxIntReq();
319-
}
320-
321-
322-
static void _internalRequestReset(nixlUcxIntReq *req) {
323-
_internalRequestFini((void *)req);
324-
_internalRequestInit((void *)req);
325-
}
326-
327294
/****************************************
328295
* Backend request management
329296
*****************************************/
330297

331298
class nixlUcxBackendH : public nixlBackendReqH {
332299
private:
300+
std::set<ucx_connection_ptr_t> connections_;
333301
std::vector<nixlUcxIntReq *> requests_;
334302
nixlUcxWorker *worker;
335303
size_t worker_id;
336304

337305
// Notification to be sent after completion of all requests
338306
struct Notif {
339-
std::string agent;
340-
nixl_blob_t payload;
341-
Notif(const std::string& remote_agent, const nixl_blob_t& msg)
342-
: agent(remote_agent), payload(msg) {}
307+
std::string agent;
308+
nixl_blob_t payload;
309+
310+
Notif(const std::string &remote_agent, const nixl_blob_t &msg)
311+
: agent(remote_agent),
312+
payload(msg) {}
343313
};
344314
std::optional<Notif> notif;
345315

@@ -358,8 +328,11 @@ class nixlUcxBackendH : public nixlBackendReqH {
358328
}
359329

360330
void
361-
append(nixlUcxIntReq *req) {
362-
requests_.push_back(req);
331+
append(nixlUcxReq req, ucx_connection_ptr_t conn) {
332+
auto req_int = static_cast<nixlUcxIntReq *>(req);
333+
req_int->setConnection(conn.get());
334+
requests_.push_back(req_int);
335+
connections_.insert(conn);
363336
}
364337

365338
virtual bool
@@ -371,15 +344,16 @@ class nixlUcxBackendH : public nixlBackendReqH {
371344
release() {
372345
// TODO: Error log: uncompleted requests found! Cancelling ...
373346
for (nixlUcxIntReq *req : requests_) {
374-
if (!req->is_complete()) {
347+
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
348+
if (ret == NIXL_IN_PROG) {
375349
// TODO: Need process this properly.
376350
// it may not be enough to cancel UCX request
377-
worker->reqCancel((nixlUcxReq)req);
351+
worker->reqCancel(req);
378352
}
379-
_internalRequestReset(req);
380-
worker->reqRelease((nixlUcxReq)req);
353+
worker->reqRelease(req);
381354
}
382355
requests_.clear();
356+
connections_.clear();
383357
return NIXL_SUCCESS;
384358
}
385359

@@ -394,37 +368,37 @@ class nixlUcxBackendH : public nixlBackendReqH {
394368
while (worker->progress())
395369
;
396370

397-
/* Go over all request updating their status */
398-
nixl_status_t out_ret = NIXL_SUCCESS;
399-
for (nixlUcxIntReq *req : requests_) {
400-
nixl_status_t ret;
401-
if (!req->is_complete()) {
402-
ret = ucx_status_to_nixl(ucp_request_check_status((nixlUcxReq)req));
403-
switch (ret) {
404-
case NIXL_SUCCESS:
405-
/* Mark as completed */
406-
req->completed();
407-
break;
408-
case NIXL_IN_PROG:
409-
out_ret = NIXL_IN_PROG;
410-
break;
411-
default:
412-
// Any other ret value is ERR and will be returned
413-
nixl_status_t conn_status = req->checkConnection(worker_id);
414-
return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
415-
}
416-
}
371+
/* If last request is incomplete, return NIXL_IN_PROG early without
372+
* checking other requests */
373+
nixlUcxIntReq *req = requests_.back();
374+
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
375+
if (ret == NIXL_IN_PROG) {
376+
return NIXL_IN_PROG;
377+
} else if (ret != NIXL_SUCCESS) {
378+
nixl_status_t conn_status = req->checkConnection(worker_id);
379+
return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
417380
}
418381

382+
/* Last request completed successfully, all the others must be in the
383+
* same state. TODO: remove extra checks? */
419384
size_t incomplete_reqs = 0;
385+
nixl_status_t out_ret = NIXL_SUCCESS;
420386
for (nixlUcxIntReq *req : requests_) {
421-
if (req->is_complete()) {
422-
_internalRequestReset(req);
423-
worker->reqRelease((nixlUcxReq)req);
424-
} else {
387+
nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
388+
if (__builtin_expect(ret == NIXL_SUCCESS, 0)) {
389+
worker->reqRelease(req);
390+
} else if (ret == NIXL_IN_PROG) {
391+
if (out_ret == NIXL_SUCCESS) {
392+
out_ret = NIXL_IN_PROG;
393+
}
425394
requests_[incomplete_reqs++] = req;
395+
} else {
396+
// Any other ret value is ERR and will be returned
397+
nixl_status_t conn_status = req->checkConnection(worker_id);
398+
out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status;
426399
}
427400
}
401+
428402
requests_.resize(incomplete_reqs);
429403
return out_ret;
430404
}
@@ -1127,13 +1101,8 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
11271101
err_handling_mode = ucx_err_mode_from_string(err_handling_mode_it->second);
11281102
}
11291103

1130-
uc = std::make_unique<nixlUcxContext>(devs,
1131-
sizeof(nixlUcxIntReq),
1132-
_internalRequestInit,
1133-
_internalRequestFini,
1134-
init_params.enableProgTh,
1135-
num_workers,
1136-
init_params.syncMode);
1104+
uc = std::make_unique<nixlUcxContext>(
1105+
devs, sizeof(nixlUcxIntReq), init_params.enableProgTh, num_workers, init_params.syncMode);
11371106

11381107
for (size_t i = 0; i < num_workers; i++) {
11391108
uws.emplace_back(std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));
@@ -1360,10 +1329,7 @@ _retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connec
13601329
/* if transfer wasn't immediately completed */
13611330
switch(ret) {
13621331
case NIXL_IN_PROG:
1363-
// TODO: this cast does not look safe
1364-
// We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt
1365-
hndl->append((nixlUcxIntReq *)req);
1366-
nixlUcxReqSetConnection(req, conn);
1332+
hndl->append(req, conn);
13671333
case NIXL_SUCCESS:
13681334
// Nothing to do
13691335
break;
@@ -1587,8 +1553,10 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
15871553
if (ret == NIXL_SUCCESS) {
15881554
nixlUcxReq req;
15891555
auto rmd = (nixlUcxPublicMetadata *)remote[0].metadataP;
1590-
ret = notifSendPriv(
1591-
remote_agent, opt_args->notifMsg, req, rmd->conn->getEp(int_handle->getWorkerId()));
1556+
ret = notifSendPriv(remote_agent,
1557+
opt_args->notifMsg,
1558+
rmd->conn->getEp(int_handle->getWorkerId()),
1559+
&req);
15921560
if (_retHelper(ret, int_handle, req, rmd->conn)) {
15931561
return ret;
15941562
}
@@ -1624,7 +1592,7 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
16241592

16251593
nixlUcxReq req;
16261594
nixl_status_t status =
1627-
notifSendPriv(notif->agent, notif->payload, req, conn->getEp(intHandle->getWorkerId()));
1595+
notifSendPriv(notif->agent, notif->payload, conn->getEp(intHandle->getWorkerId()), &req);
16281596
notif.reset();
16291597
status = _retHelper(status, intHandle, req, conn);
16301598
if (status != NIXL_SUCCESS) {
@@ -1759,23 +1727,31 @@ int nixlUcxEngine::progress() {
17591727
nixl_status_t
17601728
nixlUcxEngine::notifSendPriv(const std::string &remote_agent,
17611729
const std::string &msg,
1762-
nixlUcxReq &req,
1763-
const std::unique_ptr<nixlUcxEp> &ep) const {
1730+
const std::unique_ptr<nixlUcxEp> &ep,
1731+
nixlUcxReq *req) const {
17641732
nixlSerDes ser_des;
1765-
nixl_status_t ret;
17661733

17671734
ser_des.addStr("name", localAgent);
17681735
ser_des.addStr("msg", msg);
17691736
// TODO: replace with mpool for performance
17701737

1771-
auto buffer = std::make_unique<std::string>(ser_des.exportStr());
1772-
ret = ep->sendAm(
1773-
NOTIF_STR, NULL, 0, (void *)buffer->data(), buffer->size(), UCP_AM_SEND_FLAG_EAGER, req);
1774-
if (ret == NIXL_IN_PROG) {
1775-
nixlUcxIntReq* nReq = (nixlUcxIntReq*)req;
1776-
nReq->amBuffer = std::move(buffer);
1777-
}
1778-
return ret;
1738+
std::string *buffer = new std::string(ser_des.exportStr());
1739+
auto deleter = [buffer, req](void *completed_request, void *ptr) {
1740+
delete buffer;
1741+
if ((req == nullptr) && (completed_request != nullptr)) {
1742+
/* Caller is not interested in the request, free it */
1743+
ucp_request_free(completed_request);
1744+
}
1745+
};
1746+
1747+
return ep->sendAm(NOTIF_STR,
1748+
nullptr,
1749+
0,
1750+
(void *)buffer->data(),
1751+
buffer->size(),
1752+
UCP_AM_SEND_FLAG_EAGER,
1753+
req,
1754+
deleter);
17791755
}
17801756

17811757
ucx_connection_ptr_t
@@ -1827,26 +1803,16 @@ nixl_status_t nixlUcxEngine::getNotifs(notif_list_t &notif_list)
18271803
return NIXL_SUCCESS;
18281804
}
18291805

1830-
nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std::string &msg) const
1831-
{
1832-
nixl_status_t ret;
1833-
nixlUcxReq req;
1834-
1806+
nixl_status_t
1807+
nixlUcxEngine::genNotif(const std::string &remote_agent, const std::string &msg) const {
18351808
auto conn = getConnection(remote_agent);
18361809
if (!conn) {
18371810
return NIXL_ERR_NOT_FOUND;
18381811
}
18391812

1840-
ret = notifSendPriv(remote_agent, msg, req, conn->getEp(getWorkerId()));
1841-
switch(ret) {
1842-
case NIXL_IN_PROG:
1843-
/* do not track the request */
1844-
getWorker(getWorkerId())->reqRelease(req);
1845-
case NIXL_SUCCESS:
1846-
break;
1847-
default:
1848-
/* error case */
1849-
return ret;
1813+
nixl_status_t ret = notifSendPriv(remote_agent, msg, conn->getEp(getWorkerId()));
1814+
if (ret == NIXL_IN_PROG) {
1815+
ret = NIXL_SUCCESS;
18501816
}
1851-
return NIXL_SUCCESS;
1817+
return ret;
18521818
}

src/plugins/ucx/ucx_backend.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -288,8 +288,8 @@ class nixlUcxEngine : public nixlBackendEngine {
288288
nixl_status_t
289289
notifSendPriv(const std::string &remote_agent,
290290
const std::string &msg,
291-
nixlUcxReq &req,
292-
const std::unique_ptr<nixlUcxEp> &ep) const;
291+
const std::unique_ptr<nixlUcxEp> &ep,
292+
nixlUcxReq *req = nullptr) const;
293293

294294
ucx_connection_ptr_t
295295
getConnection(const std::string &remote_agent) const;

0 commit comments

Comments
 (0)