Skip to content

Commit 617a463

Browse files
authored
Fix build and C++ tests for FreeBSD (dmlc#10480) (dmlc#10501)
1 parent d482ba1 commit 617a463

File tree

8 files changed

+79
-11
lines changed

8 files changed

+79
-11
lines changed

.github/workflows/freebsd.yml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
name: FreeBSD
2+
3+
on: [push, pull_request]
4+
5+
permissions:
6+
contents: read # to fetch code (actions/checkout)
7+
8+
concurrency:
9+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
10+
cancel-in-progress: true
11+
12+
jobs:
13+
test:
14+
runs-on: ubuntu-latest
15+
name: A job to run test in FreeBSD
16+
steps:
17+
- uses: actions/checkout@v4
18+
with:
19+
submodules: 'true'
20+
- name: Test in FreeBSD
21+
id: test
22+
uses: vmactions/freebsd-vm@v1
23+
with:
24+
usesh: true
25+
prepare: |
26+
pkg install -y cmake git ninja googletest
27+
28+
run: |
29+
mkdir build
30+
cd build
31+
cmake .. -GNinja -DGOOGLE_TEST=ON
32+
ninja -v
33+
./testxgboost

rabit/include/rabit/internal/socket.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ namespace utils {
7878

7979
template <typename PollFD>
8080
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
81+
// For Windows and Linux, any negative timeout means an infinite timeout. On FreeBSD,
82+
// only INFTIM (-1) requests an infinite wait, so negative timeouts are normalized to -1.
8183
#if defined(_WIN32)
8284

8385
#if IS_MINGW()
@@ -88,7 +90,7 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
8890
#endif // IS_MINGW()
8991

9092
#else
91-
return poll(pfd, nfds, std::chrono::milliseconds(timeout).count());
93+
return poll(pfd, nfds, timeout.count() < 0 ? -1 : std::chrono::milliseconds(timeout).count());
9294
#endif // defined(_WIN32)
9395
}
9496

src/c_api/coll_c_api.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,11 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
7575

7676
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
7777
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
78-
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
79-
: kDft};
78+
std::int64_t timeout_clipped = kDft;
79+
if (collective::HasTimeout(timeout)) {
80+
timeout_clipped = std::min(kDft, static_cast<std::int64_t>(timeout.count()));
81+
}
82+
std::chrono::seconds wait_for{timeout_clipped};
8083

8184
common::Timer timer;
8285
timer.Start();
@@ -171,7 +174,19 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
171174
common::Timer timer;
172175
timer.Start();
173176
// Make sure no one else is waiting on the tracker.
174-
while (!ptr->first.unique()) {
177+
178+
// Quote from https://en.cppreference.com/w/cpp/memory/shared_ptr/use_count#Notes:
179+
//
180+
// In multithreaded environment, `use_count() == 1` does not imply that the object is
181+
// safe to modify because accesses to the managed object by former shared owners may not
182+
// have completed, and because new shared owners may be introduced concurrently.
183+
//
184+
// - We don't have the first case since we never access the raw pointer.
185+
//
186+
// - We don't have the second case for most of the scenarios since the tracker is a unique
187+
// object. If the free function is called before another function call, it's likely
188+
// to be a bug in the user code. The use_count should only decrease in this function.
189+
while (ptr->first.use_count() != 1) {
175190
auto ela = timer.Duration().count();
176191
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
177192
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()

src/collective/socket.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ namespace xgboost::collective {
2222
SockAddress MakeSockAddress(StringView host, in_port_t port) {
2323
struct addrinfo hints;
2424
std::memset(&hints, 0, sizeof(hints));
25-
hints.ai_protocol = SOCK_STREAM;
25+
hints.ai_socktype = SOCK_STREAM;
2626
struct addrinfo *res = nullptr;
2727
int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res);
2828
if (sig != 0) {
29+
LOG(FATAL) << "Failed to get addr info for: " << host
30+
<< ", error: " << gai_strerror(sig);
2931
return {};
3032
}
3133
if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV4)) {

tests/cpp/collective/test_worker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers,
105105
config["port"] = Integer{0};
106106
config["n_workers"] = Integer{n_workers};
107107
config["sortby"] = Integer{static_cast<std::int32_t>(Tracker::SortBy::kHost)};
108-
config["timeout"] = timeout.count();
108+
config["timeout"] = static_cast<std::int64_t>(timeout.count());
109109
return config;
110110
}
111111

tests/cpp/common/test_random.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,20 @@ TEST(ColumnSampler, GPUTest) {
6868
// Test if different threads using the same seed produce the same result
6969
TEST(ColumnSampler, ThreadSynchronisation) {
7070
Context ctx;
71-
const int64_t num_threads = 100;
71+
// NOLINTBEGIN(clang-analyzer-deadcode.DeadStores)
72+
#if defined(__linux__)
73+
std::int64_t const n_threads = std::thread::hardware_concurrency() * 128;
74+
#else
75+
std::int64_t const n_threads = std::thread::hardware_concurrency();
76+
#endif
77+
// NOLINTEND(clang-analyzer-deadcode.DeadStores)
7278
int n = 128;
7379
size_t iterations = 10;
7480
size_t levels = 5;
7581
std::vector<bst_feature_t> reference_result;
7682
std::vector<float> feature_weights;
7783
bool success = true; // Cannot use google test asserts in multithreaded region
78-
#pragma omp parallel num_threads(num_threads)
84+
#pragma omp parallel num_threads(n_threads)
7985
{
8086
for (auto j = 0ull; j < iterations; j++) {
8187
ColumnSampler cs(j);

tests/cpp/test_cache.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ TEST(DMatrixCache, MultiThread) {
5959
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
6060
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
6161

62-
auto n = std::thread::hardware_concurrency() * 128u;
62+
#if defined(__linux__)
63+
auto const n = std::thread::hardware_concurrency() * 128;
64+
#else
65+
auto const n = std::thread::hardware_concurrency();
66+
#endif
6367
CHECK_NE(n, 0);
6468
std::vector<std::shared_ptr<CacheForTest>> results(n);
6569

tests/cpp/test_learner.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,14 @@ TEST(Learner, MultiThreadedPredict) {
267267
learner->Configure();
268268

269269
std::vector<std::thread> threads;
270-
for (uint32_t thread_id = 0;
271-
thread_id < 2 * std::thread::hardware_concurrency(); ++thread_id) {
270+
271+
#if defined(__linux__)
272+
auto n_threads = std::thread::hardware_concurrency() * 4u;
273+
#else
274+
auto n_threads = std::thread::hardware_concurrency();
275+
#endif
276+
277+
for (decltype(n_threads) thread_id = 0; thread_id < n_threads; ++thread_id) {
272278
threads.emplace_back([learner, p_data] {
273279
size_t constexpr kIters = 10;
274280
auto &entry = learner->GetThreadLocal().prediction_entry;

0 commit comments

Comments
 (0)