Skip to content

Commit 4370454

Browse files
authored
[coll] Reduce the number of open files (sockets). (dmlc#10693)
Reduce the chance of hitting the "Failed to call `socket`: Too many open files" error.
1 parent d414fdf commit 4370454

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

src/collective/comm.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
141141

142142
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
143143
auto const& peer = peers[r];
144-
std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
144+
auto worker = std::make_shared<TCPSocket>();
145145
rc = std::move(rc)
146146
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
147147
<< [&] { return worker->RecvTimeout(timeout); };
@@ -161,7 +161,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
161161
}
162162

163163
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
164-
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
164+
auto peer = std::make_shared<TCPSocket>();
165165
rc = std::move(rc) << [&] {
166166
SockAddress addr;
167167
return listener->Accept(peer.get(), &addr);

src/collective/socket.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ std::size_t TCPSocket::Send(StringView str) {
118118
addr_len = sizeof(addr.V6().Handle());
119119
}
120120

121-
conn = TCPSocket::Create(addr.Domain());
121+
if (conn.IsClosed()) {
122+
conn = TCPSocket::Create(addr.Domain());
123+
}
122124
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
123125
auto non_blocking = conn.NonBlocking();
124126
auto rc = conn.NonBlocking(true);

tests/cpp/collective/test_coll_c_api.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <chrono> // for ""s
88
#include <thread> // for thread
99

10+
#include "../../../src/collective/allgather.h" // for RingAllgather
1011
#include "../../../src/collective/tracker.h"
1112
#include "test_worker.h" // for SocketTest
1213
#include "xgboost/json.h" // for Json
@@ -19,8 +20,9 @@ class TrackerAPITest : public SocketTest {};
1920
TEST_F(TrackerAPITest, CAPI) {
2021
TrackerHandle handle;
2122
Json config{Object{}};
23+
std::int32_t n_workers{2};
2224
config["dmlc_communicator"] = String{"rabit"};
23-
config["n_workers"] = 2;
25+
config["n_workers"] = n_workers;
2426
config["timeout"] = 1;
2527
auto config_str = Json::Dump(config);
2628
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
@@ -47,9 +49,21 @@ TEST_F(TrackerAPITest, CAPI) {
4749
ASSERT_NE(port, 0);
4850

4951
std::vector<std::thread> workers;
50-
using namespace std::chrono_literals; // NOLINT
51-
for (std::int32_t r = 0; r < 2; ++r) {
52-
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
52+
using std::chrono_literals::operator""s;
53+
for (std::int32_t r = 0; r < n_workers; ++r) {
54+
workers.emplace_back([=] {
55+
WorkerForTest w{host, static_cast<std::int32_t>(port), 8s, n_workers, r};
56+
// basic test
57+
std::vector<std::int32_t> data(w.Comm().World(), 0);
58+
data[w.Comm().Rank()] = w.Comm().Rank();
59+
60+
auto rc = RingAllgather(w.Comm(), common::Span{data.data(), data.size()});
61+
SafeColl(rc);
62+
63+
for (std::int32_t r = 0; r < w.Comm().World(); ++r) {
64+
ASSERT_EQ(data[r], r);
65+
}
66+
});
5367
}
5468
for (auto& w : workers) {
5569
w.join();

0 commit comments

Comments
 (0)