
Commit bb0b873

UCX/BACKEND: Add worker_id selection support (#938)
Signed-off-by: Michal Shalev <[email protected]>
1 parent 0e55b49 commit bb0b873

File tree: 5 files changed (+185, −18 lines)

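In short: this commit lets a caller pin a NIXL transfer to a specific UCX shared worker by passing a "worker_id=<n>" key in the request's customParam. A minimal caller-side sketch of the new knob, mirroring the pattern used by the new gtest below (the agent object, descriptor lists, and remote agent name are illustrative placeholders, not code from this commit):

    // Pin this transfer to shared UCX worker 3.
    nixl_opt_args_t extra_params = {};
    extra_params.customParam = "worker_id=3";

    nixlXferReqH *req = nullptr;
    nixl_status_t status = agent.createXferReq(NIXL_WRITE,
                                               src_desc_list,     // local descriptor list (placeholder)
                                               dst_desc_list,     // remote descriptor list (placeholder)
                                               remote_agent_name, // target agent name (placeholder)
                                               req,
                                               &extra_params);

When the key is absent, unparseable, or names a worker index at or above the number of shared workers, the UCX backend logs a warning where applicable and falls back to the existing getWorkerId() selection, so existing callers are unaffected.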

src/api/cpp/backend/backend_engine.h

Lines changed: 3 additions & 1 deletion
@@ -177,7 +177,9 @@ class nixlBackendEngine {
 
     // Initialize a signal for GPU transfer using memory handle from descriptor
     virtual nixl_status_t
-    prepGpuSignal(const nixlBackendMD &meta, void *signal) const {
+    prepGpuSignal(const nixlBackendMD &meta,
+                  void *signal,
+                  const nixl_opt_b_args_t *opt_args = nullptr) const {
         return NIXL_ERR_NOT_SUPPORTED;
     }
 

src/plugins/ucx/ucx_backend.cpp

Lines changed: 44 additions & 3 deletions
@@ -1388,6 +1388,35 @@ nixlUcxEngine::getWorkerId() const {
     return it->second;
 }
 
+std::optional<size_t>
+nixlUcxEngine::getWorkerIdFromOptArgs(const nixl_opt_b_args_t *opt_args) const noexcept {
+    if (!opt_args || opt_args->customParam.empty()) {
+        return std::nullopt;
+    }
+
+    constexpr std::string_view worker_id_key = "worker_id=";
+    size_t pos = opt_args->customParam.find(worker_id_key);
+    if (pos == std::string::npos) {
+        return std::nullopt;
+    }
+
+    try {
+        size_t worker_id = std::stoull(opt_args->customParam.substr(pos + worker_id_key.length()));
+
+        if (worker_id >= getSharedWorkersSize()) {
+            NIXL_WARN << "Invalid worker_id " << worker_id << " (must be < "
+                      << getSharedWorkersSize() << ")";
+            return std::nullopt;
+        }
+
+        return worker_id;
+    }
+    catch (const std::exception &e) {
+        NIXL_WARN << "Failed to parse worker_id from customParam: " << e.what();
+        return std::nullopt;
+    }
+}
+
 nixl_status_t nixlUcxEngine::prepXfer (const nixl_xfer_op_t &operation,
                                        const nixl_meta_dlist_t &local,
                                        const nixl_meta_dlist_t &remote,
@@ -1401,7 +1430,8 @@ nixl_status_t nixlUcxEngine::prepXfer (const nixl_xfer_op_t &operation,
     }
 
     /* TODO: try to get from a pool first */
-    size_t worker_id = getWorkerId();
+    const auto opt_worker_id = getWorkerIdFromOptArgs(opt_args);
+    size_t worker_id = opt_worker_id.value_or(getWorkerId());
     auto *ucx_handle = new nixlUcxBackendH(getWorker(worker_id).get(), worker_id);
 
     handle = ucx_handle;
@@ -1659,6 +1689,8 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
 
     try {
         gpu_req_hndl = nixl::ucx::createGpuXferReq(*ep, local_mems, remote_rkeys, remote_addrs);
+        NIXL_TRACE << "Created device memory list: ep=" << ep->getEp() << " handle=" << gpu_req_hndl
+                   << " worker_id=" << workerId << " num_elements=" << local_mems.size();
         return NIXL_SUCCESS;
     }
     catch (const std::exception &e) {
@@ -1690,10 +1722,19 @@ nixlUcxEngine::getGpuSignalSize(size_t &signal_size) const {
 }
 
 nixl_status_t
-nixlUcxEngine::prepGpuSignal(const nixlBackendMD &meta, void *signal) const {
+nixlUcxEngine::prepGpuSignal(const nixlBackendMD &meta,
+                             void *signal,
+                             const nixl_opt_b_args_t *opt_args) const {
     try {
         auto *ucx_meta = static_cast<const nixlUcxPrivateMetadata *>(&meta);
-        getWorker(getWorkerId())->prepGpuSignal(ucx_meta->mem, signal);
+
+        const auto opt_worker_id = getWorkerIdFromOptArgs(opt_args);
+        if (opt_worker_id) {
+            getWorker(*opt_worker_id)->prepGpuSignal(ucx_meta->mem, signal);
+        } else {
+            getWorker(getWorkerId())->prepGpuSignal(ucx_meta->mem, signal);
+        }
+
        return NIXL_SUCCESS;
    }
    catch (const std::exception &e) {
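
For reference, the accept/reject behavior of getWorkerIdFromOptArgs, distilled into a self-contained sketch (a hypothetical re-implementation for illustration only, with the NIXL types and NIXL_WARN logging stripped out; parseWorkerId and num_workers are invented names, not part of this commit):

    #include <iostream>
    #include <optional>
    #include <string>
    #include <string_view>

    // Parse "worker_id=<n>" out of a custom-parameter string and range-check it.
    std::optional<size_t>
    parseWorkerId(const std::string &custom_param, size_t num_workers) noexcept {
        constexpr std::string_view key = "worker_id=";
        const size_t pos = custom_param.find(key);
        if (pos == std::string::npos) {
            return std::nullopt; // key absent: caller falls back to getWorkerId()
        }
        try {
            const size_t id = std::stoull(custom_param.substr(pos + key.length()));
            if (id >= num_workers) {
                return std::nullopt; // out of range: rejected (warned upstream)
            }
            return id;
        }
        catch (const std::exception &) {
            return std::nullopt; // non-numeric value: rejected (warned upstream)
        }
    }

    int main() {
        std::cout << *parseWorkerId("worker_id=3", 32) << '\n';             // 3
        std::cout << parseWorkerId("worker_id=99", 32).has_value() << '\n'; // 0 (out of range)
        std::cout << parseWorkerId("foo", 32).has_value() << '\n';          // 0 (no key)
    }

Note that the key is located with find(), so "worker_id=" may appear anywhere inside customParam; only its first occurrence is parsed.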

src/plugins/ucx/ucx_backend.h

Lines changed: 8 additions & 1 deletion
@@ -204,7 +204,9 @@ class nixlUcxEngine : public nixlBackendEngine {
     getGpuSignalSize(size_t &signal_size) const override;
 
     nixl_status_t
-    prepGpuSignal(const nixlBackendMD &meta, void *signal) const override;
+    prepGpuSignal(const nixlBackendMD &meta,
+                  void *signal,
+                  const nixl_opt_b_args_t *opt_args = nullptr) const override;
 
     int
     progress();
@@ -218,6 +220,11 @@
     nixl_status_t
     checkConn(const std::string &remote_agent);
 
+private:
+    // Helper to extract worker_id from opt_args->customParam or nullopt if not found
+    [[nodiscard]] std::optional<size_t>
+    getWorkerIdFromOptArgs(const nixl_opt_b_args_t *opt_args) const noexcept;
+
 protected:
     const std::vector<std::unique_ptr<nixlUcxWorker>> &
     getWorkers() const {

src/utils/ucx/gpu_xfer_req_h.cpp

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,8 @@ createGpuXferReq(const nixlUcxEp &ep,
                          ucs_status_string(ucs_status));
     }
 
-    NIXL_DEBUG << "Created device memory list handle with " << local_mems.size() << " elements";
+    NIXL_DEBUG << "Created device memory list: ep=" << ep.getEp() << " handle=" << ucx_handle
+               << " num_elements=" << local_mems.size();
     return reinterpret_cast<nixlGpuXferReqH>(ucx_handle);
 }

test/gtest/device_api/single_write_test.cu

Lines changed: 128 additions & 12 deletions
@@ -149,7 +149,7 @@ protected:
         nixl_b_params_t params;
 
         if (getBackendName() == "UCX") {
-            params["num_workers"] = "2";
+            params["num_workers"] = std::to_string(numWorkers);
         }
 
         return params;
@@ -194,21 +194,50 @@ protected:
         agent.registerMem(reg_list);
     }
 
+    // TODO: remove this function once a blocking CreateGpuXferReq is implemented
     void
-    completeWireup(size_t from_agent, size_t to_agent) {
-        nixl_notifs_t notifs;
-        nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
-        ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";
-
-        do {
-            nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
-            ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
-            std::this_thread::sleep_for(std::chrono::milliseconds(10));
-        } while (notifs.size() == 0);
+    completeWireup(size_t from_agent, size_t to_agent,
+                   const std::vector<MemBuffer> &wireup_src,
+                   const std::vector<MemBuffer> &wireup_dst) {
+        nixl_opt_args_t wireup_params;
+
+        for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+            wireup_params.customParam = "worker_id=" + std::to_string(worker_id);
+
+            nixlXferReqH *wireup_req;
+            nixl_status_t status = getAgent(from_agent)
+                .createXferReq(NIXL_WRITE,
+                               makeDescList<nixlBasicDesc>(wireup_src, VRAM_SEG),
+                               makeDescList<nixlBasicDesc>(wireup_dst, VRAM_SEG),
+                               getAgentName(to_agent),
+                               wireup_req,
+                               &wireup_params);
+
+            ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create wireup request for worker " << worker_id;
+
+            status = getAgent(from_agent).postXferReq(wireup_req);
+            ASSERT_TRUE(status == NIXL_SUCCESS || status == NIXL_IN_PROG)
+                << "Failed to post wireup for worker " << worker_id;
+
+            nixl_status_t xfer_status;
+            do {
+                xfer_status = getAgent(from_agent).getXferStatus(wireup_req);
+                std::this_thread::sleep_for(std::chrono::milliseconds(1));
+            } while (xfer_status == NIXL_IN_PROG);
+
+            ASSERT_EQ(xfer_status, NIXL_SUCCESS) << "Warmup failed for worker " << worker_id;
+
+            status = getAgent(from_agent).releaseXferReq(wireup_req);
+            ASSERT_EQ(status, NIXL_SUCCESS);
+        }
     }
 
     void
     exchangeMD(size_t from_agent, size_t to_agent) {
+        std::vector<MemBuffer> wireup_src, wireup_dst;
+        createRegisteredMem(getAgent(from_agent), 64, 1, VRAM_SEG, wireup_src);
+        createRegisteredMem(getAgent(to_agent), 64, 1, VRAM_SEG, wireup_dst);
+
         for (size_t i = 0; i < agents.size(); i++) {
             nixl_blob_t md;
             nixl_status_t status = agents[i]->getLocalMD(md);
@@ -223,7 +252,7 @@ protected:
            }
        }
 
-        completeWireup(from_agent, to_agent);
+        completeWireup(from_agent, to_agent, wireup_src, wireup_dst);
     }
 
     void
@@ -316,6 +345,7 @@ protected:
 protected:
     static constexpr size_t SENDER_AGENT = 0;
     static constexpr size_t RECEIVER_AGENT = 1;
+    static constexpr size_t numWorkers = 32;
 
 private:
     static constexpr uint64_t DEV_ID = 0;
@@ -572,6 +602,92 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
     }
 }
 
+TEST_P(SingleWriteTest, MultipleWorkersTest) {
+    constexpr size_t size = 4 * 1024;
+    constexpr size_t num_iters = 100;
+    constexpr unsigned index = 0;
+    constexpr bool is_no_delay = true;
+    constexpr nixl_mem_t mem_type = VRAM_SEG;
+    constexpr size_t num_threads = 32;
+
+    std::vector<std::vector<MemBuffer>> src_buffers(numWorkers);
+    std::vector<std::vector<MemBuffer>> dst_buffers(numWorkers);
+    std::vector<std::vector<uint32_t>> patterns(numWorkers);
+
+    for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+        createRegisteredMem(getAgent(SENDER_AGENT), size, 1, mem_type, src_buffers[worker_id]);
+        createRegisteredMem(getAgent(RECEIVER_AGENT), size, 1, mem_type, dst_buffers[worker_id]);
+
+        constexpr size_t num_elements = size / sizeof(uint32_t);
+        patterns[worker_id].resize(num_elements);
+        for (size_t i = 0; i < num_elements; i++) {
+            patterns[worker_id][i] = 0xDEAD0000 | worker_id;
+        }
+        cudaMemcpy(static_cast<void *>(src_buffers[worker_id][0]), patterns[worker_id].data(),
+                   size, cudaMemcpyHostToDevice);
+    }
+
+    exchangeMD(SENDER_AGENT, RECEIVER_AGENT);
+
+    nixl_opt_args_t extra_params = {};
+    extra_params.hasNotif = true;
+    extra_params.notifMsg = NOTIF_MSG;
+
+    std::vector<nixlXferReqH *> xfer_reqs(numWorkers);
+    std::vector<nixlGpuXferReqH> gpu_req_hndls(numWorkers);
+
+    for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+        extra_params.customParam = "worker_id=" + std::to_string(worker_id);
+
+        nixl_status_t status = getAgent(SENDER_AGENT)
+            .createXferReq(NIXL_WRITE,
+                           makeDescList<nixlBasicDesc>(src_buffers[worker_id], mem_type),
+                           makeDescList<nixlBasicDesc>(dst_buffers[worker_id], mem_type),
+                           getAgentName(RECEIVER_AGENT),
+                           xfer_reqs[worker_id],
+                           &extra_params);
+
+        ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create xfer request for worker " << worker_id;
+
+        status = getAgent(SENDER_AGENT).createGpuXferReq(*xfer_reqs[worker_id], gpu_req_hndls[worker_id]);
+        ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create GPU xfer request for worker " << worker_id;
+    }
+
+    unsigned long long *start_time_ptr;
+    unsigned long long *end_time_ptr;
+    initTimingPublic(&start_time_ptr, &end_time_ptr);
+
+    for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+        nixl_status_t status = dispatchLaunchSingleWriteTest(GetParam(), num_threads,
+                                                             gpu_req_hndls[worker_id], index,
+                                                             0, 0, size, num_iters, is_no_delay,
+                                                             start_time_ptr, end_time_ptr);
+        ASSERT_EQ(status, NIXL_SUCCESS) << "Kernel launch failed for worker " << worker_id;
+    }
+
+    for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+        std::vector<uint32_t> received(size / sizeof(uint32_t));
+        cudaMemcpy(received.data(), static_cast<void *>(dst_buffers[worker_id][0]),
+                   size, cudaMemcpyDeviceToHost);
+
+        EXPECT_EQ(received, patterns[worker_id])
+            << "Worker " << worker_id << " full buffer verification failed";
+    }
+
+    Logger() << "MultipleWorkers test: " << numWorkers << " workers with explicit selection verified";
+
+    cudaFree(start_time_ptr);
+    cudaFree(end_time_ptr);
+
+    for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) {
+        getAgent(SENDER_AGENT).releaseGpuXferReq(gpu_req_hndls[worker_id]);
+        nixl_status_t status = getAgent(SENDER_AGENT).releaseXferReq(xfer_reqs[worker_id]);
+        EXPECT_EQ(status, NIXL_SUCCESS);
+    }
+
+    invalidateMD();
+}
+
 } // namespace gtest::nixl::gpu::single_write
 
 using gtest::nixl::gpu::single_write::SingleWriteTest;
