Skip to content

Commit d891a0e

Browse files
michal-shalevzzhang37
authored andcommitted
UCP/DEVICE: Remove redundant UCP_EP_FLAG_REMOTE_CONNECTED check (openucx#10985)
1 parent 90532bc commit d891a0e

File tree

3 files changed

+40
-41
lines changed

3 files changed

+40
-41
lines changed

src/tools/perf/cuda/ucp_cuda_kernel.cu

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,17 @@ private:
182182
params.num_elements = count;
183183
params.elements = elems;
184184

185-
ucs_status_t status = ucp_device_mem_list_create(perf.ucp.ep, &params,
186-
&m_params.mem_list);
187-
if (status != UCS_OK) {
185+
ucs_status_t status;
186+
const ucs_time_t deadline = ucs_get_time() + ucs_time_from_sec(5.0);
187+
do {
188+
ucp_worker_progress(perf.ucp.worker);
189+
status = ucp_device_mem_list_create(perf.ucp.ep, &params,
190+
&m_params.mem_list);
191+
} while ((status == UCS_ERR_NOT_CONNECTED) && (ucs_get_time() < deadline));
192+
193+
if (status == UCS_ERR_NOT_CONNECTED) {
194+
throw std::runtime_error("Timeout waiting for connection");
195+
} else if (status != UCS_OK) {
188196
throw std::runtime_error("Failed to create memory list");
189197
}
190198
}

src/ucp/core/ucp_device.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,15 +549,6 @@ ucp_device_mem_list_create(ucp_ep_h ep,
549549
ucp_ep_config_t *ep_config;
550550
uct_allocated_memory_t mem;
551551

552-
if (!(ep->flags & UCP_EP_FLAG_REMOTE_CONNECTED)) {
553-
/*
554-
* Do not log error here because UCS_ERR_NOT_CONNECTED is expected
555-
* during connection establishment. Applications are expected to retry
556-
* with progress.
557-
*/
558-
return UCS_ERR_NOT_CONNECTED;
559-
}
560-
561552
/* Parameter sanity checks and extraction */
562553
status = ucp_device_mem_list_params_check(ep->worker->context, params,
563554
&rkey_cfg_index, &local_sys_dev,

test/gtest/ucp/test_ucp_device.cc

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class test_ucp_device : public ucp_test {
3232
MODE_LAST_ELEM_COUNTER
3333
};
3434

35-
mem_list(entity &sender, entity &receiver, size_t size, unsigned count,
35+
mem_list(test_ucp_device &test, size_t size, unsigned count,
3636
ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_CUDA,
3737
mem_list_mode_t mode = MODE_DATA_ONLY);
3838
~mem_list();
@@ -54,7 +54,6 @@ class test_ucp_device : public ucp_test {
5454
void dst_pattern_check(unsigned index, uint64_t seed) const;
5555

5656
private:
57-
entity &m_receiver;
5857
std::vector<std::unique_ptr<mapped_buffer>> m_src, m_dst;
5958
std::vector<ucs::handle<ucp_rkey_h>> m_rkeys;
6059
ucp_device_mem_list_handle_h m_mem_list_h;
@@ -85,35 +84,28 @@ void test_ucp_device::init()
8584
if (!is_loopback()) {
8685
receiver().connect(&sender(), get_ep_params());
8786
}
88-
89-
ucp_device_mem_list_handle_h handle;
90-
while (ucp_device_mem_list_create(sender().ep(), NULL, &handle) ==
91-
UCS_ERR_NOT_CONNECTED) {
92-
progress();
93-
}
9487
}
9588

96-
test_ucp_device::mem_list::mem_list(entity &sender, entity &receiver,
89+
test_ucp_device::mem_list::mem_list(test_ucp_device &test,
9790
size_t size, unsigned count,
9891
ucs_memory_type_t mem_type,
99-
mem_list_mode_t mode) :
100-
m_receiver(receiver)
92+
mem_list_mode_t mode)
10193
{
10294
bool has_counter = (mode != MODE_DATA_ONLY);
10395
size_t data_count = (has_counter) ? count - 1 : count;
10496

10597
// Prepare src and dst buffers
10698
for (auto i = 0; i < data_count; ++i) {
107-
m_src.emplace_back(new mapped_buffer(size, sender, 0, mem_type));
108-
m_dst.emplace_back(new mapped_buffer(size, receiver, 0, mem_type));
109-
m_rkeys.push_back(m_dst.back()->rkey(sender));
99+
m_src.emplace_back(new mapped_buffer(size, test.sender(), 0, mem_type));
100+
m_dst.emplace_back(new mapped_buffer(size, test.receiver(), 0, mem_type));
101+
m_rkeys.push_back(m_dst.back()->rkey(test.sender()));
110102
m_src.back()->pattern_fill(SEED_SRC, size);
111103
m_dst.back()->pattern_fill(SEED_DST, size);
112104
}
113105

114106
if (has_counter) {
115-
m_dst.emplace_back(new mapped_buffer(size, receiver, 0, mem_type));
116-
m_rkeys.push_back(m_dst.back()->rkey(sender));
107+
m_dst.emplace_back(new mapped_buffer(size, test.receiver(), 0, mem_type));
108+
m_rkeys.push_back(m_dst.back()->rkey(test.sender()));
117109
m_dst.back()->pattern_fill(SEED_DST, size);
118110
}
119111

@@ -150,9 +142,17 @@ test_ucp_device::mem_list::mem_list(entity &sender, entity &receiver,
150142
params.num_elements = count;
151143
params.elements = elems.data();
152144

153-
// Create memory list
154-
ASSERT_UCS_OK(
155-
ucp_device_mem_list_create(sender.ep(), &params, &m_mem_list_h));
145+
// Create memory list (with retry on connection)
146+
ucs_status_t status = UCS_ERR_NOT_CONNECTED;
147+
test.wait_for_cond(
148+
[&]() {
149+
test.progress();
150+
status = ucp_device_mem_list_create(test.sender().ep(), &params, &m_mem_list_h);
151+
return status != UCS_ERR_NOT_CONNECTED;
152+
},
153+
[]() {}, 5.0);
154+
155+
ASSERT_UCS_OK(status);
156156
}
157157

158158
test_ucp_device::mem_list::~mem_list()
@@ -236,7 +236,7 @@ uint64_t test_ucp_device::counter_read(const mapped_buffer &buffer)
236236

237237
UCS_TEST_P(test_ucp_device, create_success)
238238
{
239-
mem_list list(sender(), receiver(), 4 * UCS_MBYTE, 4);
239+
mem_list list(*this, 4 * UCS_MBYTE, 4);
240240
EXPECT_NE(nullptr, list.handle());
241241
}
242242

@@ -337,7 +337,7 @@ UCS_TEST_P(test_ucp_device, create_fail)
337337
UCS_TEST_P(test_ucp_device, get_mem_list_length)
338338
{
339339
constexpr unsigned num_elements = 8;
340-
mem_list list(sender(), receiver(), 1 * UCS_KBYTE, num_elements);
340+
mem_list list(*this, 1 * UCS_KBYTE, num_elements);
341341
EXPECT_EQ(num_elements, ucp_device_get_mem_list_length(list.handle()));
342342
}
343343

@@ -548,7 +548,7 @@ class test_ucp_device_xfer : public test_ucp_device_kernel {
548548
UCS_TEST_P(test_ucp_device_xfer, put_single)
549549
{
550550
static constexpr size_t size = 32 * UCS_KBYTE;
551-
mem_list list(sender(), receiver(), size, 6);
551+
mem_list list(*this, size, 6);
552552

553553
// Perform the transfer
554554
static constexpr unsigned mem_list_index = 3;
@@ -577,7 +577,7 @@ UCS_TEST_SKIP_COND_P(test_ucp_device_xfer, put_single_stress_test,
577577

578578
static constexpr size_t size = 8;
579579
static constexpr unsigned mem_list_index = 0;
580-
mem_list list(sender(), receiver(), size, 1);
580+
mem_list list(*this, size, 1);
581581

582582
// Perform the transfer
583583
auto params = init_params();
@@ -601,8 +601,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi)
601601
{
602602
static constexpr size_t size = 32 * UCS_KBYTE;
603603
unsigned count = get_multi_elem_count();
604-
mem_list list(sender(), receiver(), size, count + 1, UCS_MEMORY_TYPE_CUDA,
605-
mem_list::MODE_LAST_ELEM_COUNTER);
604+
mem_list list(*this, size, count + 1,
605+
UCS_MEMORY_TYPE_CUDA, mem_list::MODE_LAST_ELEM_COUNTER);
606606

607607
const unsigned counter_index = count;
608608
list.dst_counter_init(counter_index);
@@ -631,7 +631,7 @@ UCS_TEST_SKIP_COND_P(test_ucp_device_xfer, put_multi_stress_test,
631631

632632
static constexpr size_t size = 8;
633633
unsigned count = get_multi_elem_count();
634-
mem_list list(sender(), receiver(), size, count + 1);
634+
mem_list list(*this, size, count + 1);
635635

636636
const unsigned counter_index = count;
637637
list.dst_counter_init(counter_index);
@@ -657,8 +657,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
657657
{
658658
static constexpr size_t size = 32 * UCS_KBYTE;
659659
unsigned total_count = get_multi_elem_count() * 2;
660-
mem_list list(sender(), receiver(), size, total_count + 1, UCS_MEMORY_TYPE_CUDA,
661-
mem_list::MODE_LAST_ELEM_COUNTER);
660+
mem_list list(*this, size, total_count + 1,
661+
UCS_MEMORY_TYPE_CUDA, mem_list::MODE_LAST_ELEM_COUNTER);
662662

663663
const unsigned counter_index = total_count;
664664
list.dst_counter_init(counter_index);
@@ -708,7 +708,7 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
708708
UCS_TEST_P(test_ucp_device_xfer, counter)
709709
{
710710
const size_t size = counter_size();
711-
mem_list list(sender(), receiver(), size, 1, UCS_MEMORY_TYPE_CUDA,
711+
mem_list list(*this, size, 1, UCS_MEMORY_TYPE_CUDA,
712712
mem_list::MODE_COUNTER_ONLY);
713713

714714
static constexpr unsigned mem_list_index = 0;

0 commit comments

Comments
 (0)