@@ -149,7 +149,7 @@ protected:
149149 nixl_b_params_t params;
150150
151151 if (getBackendName () == " UCX" ) {
152- params[" num_workers" ] = " 2 " ;
152+ params[" num_workers" ] = std::to_string (numWorkers) ;
153153 }
154154
155155 return params;
@@ -194,21 +194,50 @@ protected:
194194 agent.registerMem (reg_list);
195195 }
196196
197+ // TODO: remove this function once a blocking CreateGpuXferReq is implemented
197198 void
198- completeWireup (size_t from_agent, size_t to_agent) {
199- nixl_notifs_t notifs;
200- nixl_status_t status = getAgent (from_agent).genNotif (getAgentName (to_agent), NOTIF_MSG);
201- ASSERT_EQ (status, NIXL_SUCCESS) << " Failed to complete wireup" ;
202-
203- do {
204- nixl_status_t ret = getAgent (to_agent).getNotifs (notifs);
205- ASSERT_EQ (ret, NIXL_SUCCESS) << " Failed to get notifications during wireup" ;
206- std::this_thread::sleep_for (std::chrono::milliseconds (10 ));
207- } while (notifs.size () == 0 );
199+ completeWireup (size_t from_agent, size_t to_agent,
200+ const std::vector<MemBuffer> &wireup_src,
201+ const std::vector<MemBuffer> &wireup_dst) {
202+ nixl_opt_args_t wireup_params;
203+
204+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
205+ wireup_params.customParam = " worker_id=" + std::to_string (worker_id);
206+
207+ nixlXferReqH *wireup_req;
208+ nixl_status_t status = getAgent (from_agent)
209+ .createXferReq (NIXL_WRITE,
210+ makeDescList<nixlBasicDesc>(wireup_src, VRAM_SEG),
211+ makeDescList<nixlBasicDesc>(wireup_dst, VRAM_SEG),
212+ getAgentName (to_agent),
213+ wireup_req,
214+ &wireup_params);
215+
216+ ASSERT_EQ (status, NIXL_SUCCESS) << " Failed to create wireup request for worker " << worker_id;
217+
218+ status = getAgent (from_agent).postXferReq (wireup_req);
219+ ASSERT_TRUE (status == NIXL_SUCCESS || status == NIXL_IN_PROG)
220+ << " Failed to post wireup for worker " << worker_id;
221+
222+ nixl_status_t xfer_status;
223+ do {
224+ xfer_status = getAgent (from_agent).getXferStatus (wireup_req);
225+ std::this_thread::sleep_for (std::chrono::milliseconds (1 ));
226+ } while (xfer_status == NIXL_IN_PROG);
227+
228+ ASSERT_EQ (xfer_status, NIXL_SUCCESS) << " Warmup failed for worker " << worker_id;
229+
230+ status = getAgent (from_agent).releaseXferReq (wireup_req);
231+ ASSERT_EQ (status, NIXL_SUCCESS);
232+ }
208233 }
209234
210235 void
211236 exchangeMD (size_t from_agent, size_t to_agent) {
237+ std::vector<MemBuffer> wireup_src, wireup_dst;
238+ createRegisteredMem (getAgent (from_agent), 64 , 1 , VRAM_SEG, wireup_src);
239+ createRegisteredMem (getAgent (to_agent), 64 , 1 , VRAM_SEG, wireup_dst);
240+
212241 for (size_t i = 0 ; i < agents.size (); i++) {
213242 nixl_blob_t md;
214243 nixl_status_t status = agents[i]->getLocalMD (md);
@@ -223,7 +252,7 @@ protected:
223252 }
224253 }
225254
226- completeWireup (from_agent, to_agent);
255+ completeWireup (from_agent, to_agent, wireup_src, wireup_dst );
227256 }
228257
229258 void
@@ -316,6 +345,7 @@ protected:
316345protected:
317346 static constexpr size_t SENDER_AGENT = 0 ;
318347 static constexpr size_t RECEIVER_AGENT = 1 ;
348+ static constexpr size_t numWorkers = 32 ;
319349
320350private:
321351 static constexpr uint64_t DEV_ID = 0 ;
@@ -572,6 +602,92 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
572602 }
573603}
574604
605+ TEST_P (SingleWriteTest, MultipleWorkersTest) {
606+ constexpr size_t size = 4 * 1024 ;
607+ constexpr size_t num_iters = 100 ;
608+ constexpr unsigned index = 0 ;
609+ constexpr bool is_no_delay = true ;
610+ constexpr nixl_mem_t mem_type = VRAM_SEG;
611+ constexpr size_t num_threads = 32 ;
612+
613+ std::vector<std::vector<MemBuffer>> src_buffers (numWorkers);
614+ std::vector<std::vector<MemBuffer>> dst_buffers (numWorkers);
615+ std::vector<std::vector<uint32_t >> patterns (numWorkers);
616+
617+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
618+ createRegisteredMem (getAgent (SENDER_AGENT), size, 1 , mem_type, src_buffers[worker_id]);
619+ createRegisteredMem (getAgent (RECEIVER_AGENT), size, 1 , mem_type, dst_buffers[worker_id]);
620+
621+ constexpr size_t num_elements = size / sizeof (uint32_t );
622+ patterns[worker_id].resize (num_elements);
623+ for (size_t i = 0 ; i < num_elements; i++) {
624+ patterns[worker_id][i] = 0xDEAD0000 | worker_id;
625+ }
626+ cudaMemcpy (static_cast <void *>(src_buffers[worker_id][0 ]), patterns[worker_id].data (),
627+ size, cudaMemcpyHostToDevice);
628+ }
629+
630+ exchangeMD (SENDER_AGENT, RECEIVER_AGENT);
631+
632+ nixl_opt_args_t extra_params = {};
633+ extra_params.hasNotif = true ;
634+ extra_params.notifMsg = NOTIF_MSG;
635+
636+ std::vector<nixlXferReqH *> xfer_reqs (numWorkers);
637+ std::vector<nixlGpuXferReqH> gpu_req_hndls (numWorkers);
638+
639+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
640+ extra_params.customParam = " worker_id=" + std::to_string (worker_id);
641+
642+ nixl_status_t status = getAgent (SENDER_AGENT)
643+ .createXferReq (NIXL_WRITE,
644+ makeDescList<nixlBasicDesc>(src_buffers[worker_id], mem_type),
645+ makeDescList<nixlBasicDesc>(dst_buffers[worker_id], mem_type),
646+ getAgentName (RECEIVER_AGENT),
647+ xfer_reqs[worker_id],
648+ &extra_params);
649+
650+ ASSERT_EQ (status, NIXL_SUCCESS) << " Failed to create xfer request for worker " << worker_id;
651+
652+ status = getAgent (SENDER_AGENT).createGpuXferReq (*xfer_reqs[worker_id], gpu_req_hndls[worker_id]);
653+ ASSERT_EQ (status, NIXL_SUCCESS) << " Failed to create GPU xfer request for worker " << worker_id;
654+ }
655+
656+ unsigned long long *start_time_ptr;
657+ unsigned long long *end_time_ptr;
658+ initTimingPublic (&start_time_ptr, &end_time_ptr);
659+
660+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
661+ nixl_status_t status = dispatchLaunchSingleWriteTest (GetParam (), num_threads,
662+ gpu_req_hndls[worker_id], index,
663+ 0 , 0 , size, num_iters, is_no_delay,
664+ start_time_ptr, end_time_ptr);
665+ ASSERT_EQ (status, NIXL_SUCCESS) << " Kernel launch failed for worker " << worker_id;
666+ }
667+
668+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
669+ std::vector<uint32_t > received (size / sizeof (uint32_t ));
670+ cudaMemcpy (received.data (), static_cast <void *>(dst_buffers[worker_id][0 ]),
671+ size, cudaMemcpyDeviceToHost);
672+
673+ EXPECT_EQ (received, patterns[worker_id])
674+ << " Worker " << worker_id << " full buffer verification failed" ;
675+ }
676+
677+ Logger () << " MultipleWorkers test: " << numWorkers << " workers with explicit selection verified" ;
678+
679+ cudaFree (start_time_ptr);
680+ cudaFree (end_time_ptr);
681+
682+ for (size_t worker_id = 0 ; worker_id < numWorkers; worker_id++) {
683+ getAgent (SENDER_AGENT).releaseGpuXferReq (gpu_req_hndls[worker_id]);
684+ nixl_status_t status = getAgent (SENDER_AGENT).releaseXferReq (xfer_reqs[worker_id]);
685+ EXPECT_EQ (status, NIXL_SUCCESS);
686+ }
687+
688+ invalidateMD ();
689+ }
690+
575691} // namespace gtest::nixl::gpu::single_write
576692
577693using gtest::nixl::gpu::single_write::SingleWriteTest;
0 commit comments