@@ -32,7 +32,7 @@ class test_ucp_device : public ucp_test {
3232 MODE_LAST_ELEM_COUNTER
3333 };
3434
35- mem_list (entity &sender, entity &receiver , size_t size, unsigned count,
35+ mem_list (test_ucp_device &test , size_t size, unsigned count,
3636 ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_CUDA,
3737 mem_list_mode_t mode = MODE_DATA_ONLY);
3838 ~mem_list ();
@@ -54,7 +54,6 @@ class test_ucp_device : public ucp_test {
5454 void dst_pattern_check (unsigned index, uint64_t seed) const ;
5555
5656 private:
57- entity &m_receiver;
5857 std::vector<std::unique_ptr<mapped_buffer>> m_src, m_dst;
5958 std::vector<ucs::handle<ucp_rkey_h>> m_rkeys;
6059 ucp_device_mem_list_handle_h m_mem_list_h;
@@ -85,35 +84,28 @@ void test_ucp_device::init()
8584 if (!is_loopback ()) {
8685 receiver ().connect (&sender (), get_ep_params ());
8786 }
88-
89- ucp_device_mem_list_handle_h handle;
90- while (ucp_device_mem_list_create (sender ().ep (), NULL , &handle) ==
91- UCS_ERR_NOT_CONNECTED) {
92- progress ();
93- }
9487}
9588
96- test_ucp_device::mem_list::mem_list (entity &sender, entity &receiver ,
89+ test_ucp_device::mem_list::mem_list (test_ucp_device &test ,
9790 size_t size, unsigned count,
9891 ucs_memory_type_t mem_type,
99- mem_list_mode_t mode) :
100- m_receiver(receiver)
92+ mem_list_mode_t mode)
10193{
10294 bool has_counter = (mode != MODE_DATA_ONLY);
10395 size_t data_count = (has_counter) ? count - 1 : count;
10496
10597 // Prepare src and dst buffers
10698 for (auto i = 0 ; i < data_count; ++i) {
107- m_src.emplace_back (new mapped_buffer (size, sender, 0 , mem_type));
108- m_dst.emplace_back (new mapped_buffer (size, receiver, 0 , mem_type));
109- m_rkeys.push_back (m_dst.back ()->rkey (sender));
99+ m_src.emplace_back (new mapped_buffer (size, test. sender () , 0 , mem_type));
100+ m_dst.emplace_back (new mapped_buffer (size, test. receiver () , 0 , mem_type));
101+ m_rkeys.push_back (m_dst.back ()->rkey (test. sender () ));
110102 m_src.back ()->pattern_fill (SEED_SRC, size);
111103 m_dst.back ()->pattern_fill (SEED_DST, size);
112104 }
113105
114106 if (has_counter) {
115- m_dst.emplace_back (new mapped_buffer (size, receiver, 0 , mem_type));
116- m_rkeys.push_back (m_dst.back ()->rkey (sender));
107+ m_dst.emplace_back (new mapped_buffer (size, test. receiver () , 0 , mem_type));
108+ m_rkeys.push_back (m_dst.back ()->rkey (test. sender () ));
117109 m_dst.back ()->pattern_fill (SEED_DST, size);
118110 }
119111
@@ -150,9 +142,17 @@ test_ucp_device::mem_list::mem_list(entity &sender, entity &receiver,
150142 params.num_elements = count;
151143 params.elements = elems.data ();
152144
153- // Create memory list
154- ASSERT_UCS_OK (
155- ucp_device_mem_list_create (sender.ep (), ¶ms, &m_mem_list_h));
145+ // Create memory list (with retry on connection)
146+ ucs_status_t status = UCS_ERR_NOT_CONNECTED;
147+ test.wait_for_cond (
148+ [&]() {
149+ test.progress ();
150+ status = ucp_device_mem_list_create (test.sender ().ep (), ¶ms, &m_mem_list_h);
151+ return status != UCS_ERR_NOT_CONNECTED;
152+ },
153+ []() {}, 5.0 );
154+
155+ ASSERT_UCS_OK (status);
156156}
157157
158158test_ucp_device::mem_list::~mem_list ()
@@ -236,7 +236,7 @@ uint64_t test_ucp_device::counter_read(const mapped_buffer &buffer)
236236
237237UCS_TEST_P (test_ucp_device, create_success)
238238{
239- mem_list list (sender (), receiver () , 4 * UCS_MBYTE, 4 );
239+ mem_list list (* this , 4 * UCS_MBYTE, 4 );
240240 EXPECT_NE (nullptr , list.handle ());
241241}
242242
@@ -337,7 +337,7 @@ UCS_TEST_P(test_ucp_device, create_fail)
337337UCS_TEST_P (test_ucp_device, get_mem_list_length)
338338{
339339 constexpr unsigned num_elements = 8 ;
340- mem_list list (sender (), receiver () , 1 * UCS_KBYTE, num_elements);
340+ mem_list list (* this , 1 * UCS_KBYTE, num_elements);
341341 EXPECT_EQ (num_elements, ucp_device_get_mem_list_length (list.handle ()));
342342}
343343
@@ -548,7 +548,7 @@ class test_ucp_device_xfer : public test_ucp_device_kernel {
548548UCS_TEST_P (test_ucp_device_xfer, put_single)
549549{
550550 static constexpr size_t size = 32 * UCS_KBYTE;
551- mem_list list (sender (), receiver () , size, 6 );
551+ mem_list list (* this , size, 6 );
552552
553553 // Perform the transfer
554554 static constexpr unsigned mem_list_index = 3 ;
@@ -577,7 +577,7 @@ UCS_TEST_SKIP_COND_P(test_ucp_device_xfer, put_single_stress_test,
577577
578578 static constexpr size_t size = 8 ;
579579 static constexpr unsigned mem_list_index = 0 ;
580- mem_list list (sender (), receiver () , size, 1 );
580+ mem_list list (* this , size, 1 );
581581
582582 // Perform the transfer
583583 auto params = init_params ();
@@ -601,8 +601,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi)
601601{
602602 static constexpr size_t size = 32 * UCS_KBYTE;
603603 unsigned count = get_multi_elem_count ();
604- mem_list list (sender (), receiver (), size, count + 1 , UCS_MEMORY_TYPE_CUDA ,
605- mem_list::MODE_LAST_ELEM_COUNTER);
604+ mem_list list (* this , size, count + 1 ,
605+ UCS_MEMORY_TYPE_CUDA, mem_list::MODE_LAST_ELEM_COUNTER);
606606
607607 const unsigned counter_index = count;
608608 list.dst_counter_init (counter_index);
@@ -631,7 +631,7 @@ UCS_TEST_SKIP_COND_P(test_ucp_device_xfer, put_multi_stress_test,
631631
632632 static constexpr size_t size = 8 ;
633633 unsigned count = get_multi_elem_count ();
634- mem_list list (sender (), receiver () , size, count + 1 );
634+ mem_list list (* this , size, count + 1 );
635635
636636 const unsigned counter_index = count;
637637 list.dst_counter_init (counter_index);
@@ -657,8 +657,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
657657{
658658 static constexpr size_t size = 32 * UCS_KBYTE;
659659 unsigned total_count = get_multi_elem_count () * 2 ;
660- mem_list list (sender (), receiver (), size, total_count + 1 , UCS_MEMORY_TYPE_CUDA ,
661- mem_list::MODE_LAST_ELEM_COUNTER);
660+ mem_list list (* this , size, total_count + 1 ,
661+ UCS_MEMORY_TYPE_CUDA, mem_list::MODE_LAST_ELEM_COUNTER);
662662
663663 const unsigned counter_index = total_count;
664664 list.dst_counter_init (counter_index);
@@ -708,7 +708,7 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
708708UCS_TEST_P (test_ucp_device_xfer, counter)
709709{
710710 const size_t size = counter_size ();
711- mem_list list (sender (), receiver () , size, 1 , UCS_MEMORY_TYPE_CUDA,
711+ mem_list list (* this , size, 1 , UCS_MEMORY_TYPE_CUDA,
712712 mem_list::MODE_COUNTER_ONLY);
713713
714714 static constexpr unsigned mem_list_index = 0 ;
0 commit comments