2424#include < optional>
2525#include < limits>
2626#include < future>
27+ #include < set>
2728#include < string.h>
2829#include < unistd.h>
2930#include " absl/strings/numbers.h"
@@ -271,20 +272,12 @@ void nixlUcxEngine::vramFiniCtx()
271272
272273class nixlUcxIntReq {
273274public:
274- std::unique_ptr<std::string> amBuffer;
275-
276- bool
277- is_complete () const {
278- return completed_;
275+ operator nixlUcxReq () noexcept {
276+ return static_cast <nixlUcxReq>(this );
279277 }
280278
281279 void
282- completed () {
283- completed_ = true ;
284- }
285-
286- void
287- setConnection (ucx_connection_ptr_t conn) {
280+ setConnection (nixlUcxConnection *conn) {
288281 conn_ = conn;
289282 }
290283
@@ -295,51 +288,28 @@ class nixlUcxIntReq {
295288 }
296289
297290private:
298- bool completed_ = false ;
299- ucx_connection_ptr_t conn_;
291+ nixlUcxConnection *conn_;
300292};
301293
302- static void
303- nixlUcxReqSetConnection (nixlUcxReq req, ucx_connection_ptr_t conn) {
304- nixlUcxIntReq *req_int = reinterpret_cast <nixlUcxIntReq *>(req);
305- req_int->setConnection (conn);
306- }
307-
308- static void _internalRequestInit (void *request)
309- {
310- /* Initialize request in-place (aka "placement new")*/
311- new (request) nixlUcxIntReq;
312- }
313-
314- static void _internalRequestFini (void *request)
315- {
316- /* Finalize request */
317- nixlUcxIntReq *req = (nixlUcxIntReq*)request;
318- req->~nixlUcxIntReq ();
319- }
320-
321-
322- static void _internalRequestReset (nixlUcxIntReq *req) {
323- _internalRequestFini ((void *)req);
324- _internalRequestInit ((void *)req);
325- }
326-
327294/* ***************************************
328295 * Backend request management
329296*****************************************/
330297
331298class nixlUcxBackendH : public nixlBackendReqH {
332299private:
300+ std::set<ucx_connection_ptr_t > connections_;
333301 std::vector<nixlUcxIntReq *> requests_;
334302 nixlUcxWorker *worker;
335303 size_t worker_id;
336304
337305 // Notification to be sent after completion of all requests
338306 struct Notif {
339- std::string agent;
340- nixl_blob_t payload;
341- Notif (const std::string& remote_agent, const nixl_blob_t & msg)
342- : agent(remote_agent), payload(msg) {}
307+ std::string agent;
308+ nixl_blob_t payload;
309+
310+ Notif (const std::string &remote_agent, const nixl_blob_t &msg)
311+ : agent(remote_agent),
312+ payload (msg) {}
343313 };
344314 std::optional<Notif> notif;
345315
@@ -358,8 +328,11 @@ class nixlUcxBackendH : public nixlBackendReqH {
358328 }
359329
360330 void
361- append (nixlUcxIntReq *req) {
362- requests_.push_back (req);
331+ append (nixlUcxReq req, ucx_connection_ptr_t conn) {
332+ auto req_int = static_cast <nixlUcxIntReq *>(req);
333+ req_int->setConnection (conn.get ());
334+ requests_.push_back (req_int);
335+ connections_.insert (conn);
363336 }
364337
365338 virtual bool
@@ -371,15 +344,16 @@ class nixlUcxBackendH : public nixlBackendReqH {
371344 release () {
372345 // TODO: Error log: uncompleted requests found! Cancelling ...
373346 for (nixlUcxIntReq *req : requests_) {
374- if (!req->is_complete ()) {
347+ nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
348+ if (ret == NIXL_IN_PROG) {
375349 // TODO: Need process this properly.
376350 // it may not be enough to cancel UCX request
377- worker->reqCancel ((nixlUcxReq) req);
351+ worker->reqCancel (req);
378352 }
379- _internalRequestReset (req);
380- worker->reqRelease ((nixlUcxReq)req);
353+ worker->reqRelease (req);
381354 }
382355 requests_.clear ();
356+ connections_.clear ();
383357 return NIXL_SUCCESS;
384358 }
385359
@@ -394,37 +368,37 @@ class nixlUcxBackendH : public nixlBackendReqH {
394368 while (worker->progress ())
395369 ;
396370
397- /* Go over all request updating their status */
398- nixl_status_t out_ret = NIXL_SUCCESS;
399- for (nixlUcxIntReq *req : requests_) {
400- nixl_status_t ret;
401- if (!req->is_complete ()) {
402- ret = ucx_status_to_nixl (ucp_request_check_status ((nixlUcxReq)req));
403- switch (ret) {
404- case NIXL_SUCCESS:
405- /* Mark as completed */
406- req->completed ();
407- break ;
408- case NIXL_IN_PROG:
409- out_ret = NIXL_IN_PROG;
410- break ;
411- default :
412- // Any other ret value is ERR and will be returned
413- nixl_status_t conn_status = req->checkConnection (worker_id);
414- return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
415- }
416- }
371+ /* If last request is incomplete, return NIXL_IN_PROG early without
372+ * checking other requests */
373+ nixlUcxIntReq *req = requests_.back ();
374+ nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
375+ if (ret == NIXL_IN_PROG) {
376+ return NIXL_IN_PROG;
377+ } else if (ret != NIXL_SUCCESS) {
378+ nixl_status_t conn_status = req->checkConnection (worker_id);
379+ return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
417380 }
418381
382+ /* Last request completed successfully, all the others must be in the
383+ * same state. TODO: remove extra checks? */
419384 size_t incomplete_reqs = 0 ;
385+ nixl_status_t out_ret = NIXL_SUCCESS;
420386 for (nixlUcxIntReq *req : requests_) {
421- if (req->is_complete ()) {
422- _internalRequestReset (req);
423- worker->reqRelease ((nixlUcxReq)req);
424- } else {
387+ nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
388+ if (__builtin_expect (ret == NIXL_SUCCESS, 0 )) {
389+ worker->reqRelease (req);
390+ } else if (ret == NIXL_IN_PROG) {
391+ if (out_ret == NIXL_SUCCESS) {
392+ out_ret = NIXL_IN_PROG;
393+ }
425394 requests_[incomplete_reqs++] = req;
395+ } else {
396+ // Any other ret value is ERR and will be returned
397+ nixl_status_t conn_status = req->checkConnection (worker_id);
398+ out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status;
426399 }
427400 }
401+
428402 requests_.resize (incomplete_reqs);
429403 return out_ret;
430404 }
@@ -1127,13 +1101,8 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
11271101 err_handling_mode = ucx_err_mode_from_string (err_handling_mode_it->second );
11281102 }
11291103
1130- uc = std::make_unique<nixlUcxContext>(devs,
1131- sizeof (nixlUcxIntReq),
1132- _internalRequestInit,
1133- _internalRequestFini,
1134- init_params.enableProgTh ,
1135- num_workers,
1136- init_params.syncMode );
1104+ uc = std::make_unique<nixlUcxContext>(
1105+ devs, sizeof (nixlUcxIntReq), init_params.enableProgTh , num_workers, init_params.syncMode );
11371106
11381107 for (size_t i = 0 ; i < num_workers; i++) {
11391108 uws.emplace_back (std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));
@@ -1360,10 +1329,7 @@ _retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connec
13601329 /* if transfer wasn't immediately completed */
13611330 switch (ret) {
13621331 case NIXL_IN_PROG:
1363- // TODO: this cast does not look safe
1364- // We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt
1365- hndl->append ((nixlUcxIntReq *)req);
1366- nixlUcxReqSetConnection (req, conn);
1332+ hndl->append (req, conn);
13671333 case NIXL_SUCCESS:
13681334 // Nothing to do
13691335 break ;
@@ -1587,8 +1553,10 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
15871553 if (ret == NIXL_SUCCESS) {
15881554 nixlUcxReq req;
15891555 auto rmd = (nixlUcxPublicMetadata *)remote[0 ].metadataP ;
1590- ret = notifSendPriv (
1591- remote_agent, opt_args->notifMsg , req, rmd->conn ->getEp (int_handle->getWorkerId ()));
1556+ ret = notifSendPriv (remote_agent,
1557+ opt_args->notifMsg ,
1558+ rmd->conn ->getEp (int_handle->getWorkerId ()),
1559+ &req);
15921560 if (_retHelper (ret, int_handle, req, rmd->conn )) {
15931561 return ret;
15941562 }
@@ -1624,7 +1592,7 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
16241592
16251593 nixlUcxReq req;
16261594 nixl_status_t status =
1627- notifSendPriv (notif->agent , notif->payload , req, conn->getEp (intHandle->getWorkerId ()));
1595+ notifSendPriv (notif->agent , notif->payload , conn->getEp (intHandle->getWorkerId ()), &req );
16281596 notif.reset ();
16291597 status = _retHelper (status, intHandle, req, conn);
16301598 if (status != NIXL_SUCCESS) {
@@ -1759,23 +1727,31 @@ int nixlUcxEngine::progress() {
17591727nixl_status_t
17601728nixlUcxEngine::notifSendPriv (const std::string &remote_agent,
17611729 const std::string &msg,
1762- nixlUcxReq &req ,
1763- const std::unique_ptr<nixlUcxEp> &ep ) const {
1730+ const std::unique_ptr<nixlUcxEp> &ep ,
1731+ nixlUcxReq *req ) const {
17641732 nixlSerDes ser_des;
1765- nixl_status_t ret;
17661733
17671734 ser_des.addStr (" name" , localAgent);
17681735 ser_des.addStr (" msg" , msg);
17691736 // TODO: replace with mpool for performance
17701737
1771- auto buffer = std::make_unique<std::string>(ser_des.exportStr ());
1772- ret = ep->sendAm (
1773- NOTIF_STR, NULL , 0 , (void *)buffer->data (), buffer->size (), UCP_AM_SEND_FLAG_EAGER, req);
1774- if (ret == NIXL_IN_PROG) {
1775- nixlUcxIntReq* nReq = (nixlUcxIntReq*)req;
1776- nReq->amBuffer = std::move (buffer);
1777- }
1778- return ret;
1738+ std::string *buffer = new std::string (ser_des.exportStr ());
1739+ auto deleter = [buffer, req](void *completed_request, void *ptr) {
1740+ delete buffer;
1741+ if ((req == nullptr ) && (completed_request != nullptr )) {
1742+ /* Caller is not interested in the request, free it */
1743+ ucp_request_free (completed_request);
1744+ }
1745+ };
1746+
1747+ return ep->sendAm (NOTIF_STR,
1748+ nullptr ,
1749+ 0 ,
1750+ (void *)buffer->data (),
1751+ buffer->size (),
1752+ UCP_AM_SEND_FLAG_EAGER,
1753+ req,
1754+ deleter);
17791755}
17801756
17811757ucx_connection_ptr_t
@@ -1827,26 +1803,16 @@ nixl_status_t nixlUcxEngine::getNotifs(notif_list_t ¬if_list)
18271803 return NIXL_SUCCESS;
18281804}
18291805
1830- nixl_status_t nixlUcxEngine::genNotif (const std::string &remote_agent, const std::string &msg) const
1831- {
1832- nixl_status_t ret;
1833- nixlUcxReq req;
1834-
1806+ nixl_status_t
1807+ nixlUcxEngine::genNotif (const std::string &remote_agent, const std::string &msg) const {
18351808 auto conn = getConnection (remote_agent);
18361809 if (!conn) {
18371810 return NIXL_ERR_NOT_FOUND;
18381811 }
18391812
1840- ret = notifSendPriv (remote_agent, msg, req, conn->getEp (getWorkerId ()));
1841- switch (ret) {
1842- case NIXL_IN_PROG:
1843- /* do not track the request */
1844- getWorker (getWorkerId ())->reqRelease (req);
1845- case NIXL_SUCCESS:
1846- break ;
1847- default :
1848- /* error case */
1849- return ret;
1813+ nixl_status_t ret = notifSendPriv (remote_agent, msg, conn->getEp (getWorkerId ()));
1814+ if (ret == NIXL_IN_PROG) {
1815+ ret = NIXL_SUCCESS;
18501816 }
1851- return NIXL_SUCCESS ;
1817+ return ret ;
18521818}
0 commit comments