Skip to content

Commit 835e59e

Browse files
authored
Use a thread pool for external memory. (dmlc#10288)
1 parent ee2afb3 commit 835e59e

File tree

3 files changed

+157
-5
lines changed

3 files changed

+157
-5
lines changed

src/common/threadpool.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/**
2+
* Copyright 2024, XGBoost Contributors
3+
*/
4+
#pragma once
5+
#include <condition_variable> // for condition_variable
#include <cstdint> // for int32_t
#include <exception> // for current_exception
#include <functional> // for function
#include <future> // for promise
#include <memory> // for make_shared
#include <mutex> // for mutex, unique_lock
#include <queue> // for queue
#include <thread> // for thread
#include <type_traits> // for invoke_result_t
#include <utility> // for move
#include <vector> // for vector
16+
17+
namespace xgboost::common {
18+
/**
 * @brief Simple implementation of a thread pool.
 *
 * A fixed number of worker threads pop tasks from a shared FIFO queue. Tasks that are
 * still queued when the pool is destroyed are executed before the destructor returns.
 *
 * If `n_threads` is not positive no worker is created and submitted tasks are never
 * executed.
 */
class ThreadPool {
  std::mutex mu_;
  std::queue<std::function<void()>> tasks_;
  std::condition_variable cv_;
  std::vector<std::thread> pool_;
  bool stop_{false};

 public:
  /**
   * @brief Start the workers.
   *
   * @param n_threads Number of worker threads to spawn.
   */
  explicit ThreadPool(std::int32_t n_threads) {
    for (std::int32_t i = 0; i < n_threads; ++i) {
      pool_.emplace_back([this] {
        while (true) {
          std::unique_lock lock{mu_};
          cv_.wait(lock, [this] { return !tasks_.empty() || stop_; });

          if (tasks_.empty()) {
            // The predicate held with an empty queue, hence stop_ is set and there is
            // nothing left to drain.
            return;
          }

          // Move the task out instead of copying the std::function.
          auto fn = std::move(tasks_.front());
          tasks_.pop();
          // Never run a task while holding the lock: other workers can keep draining
          // the queue in parallel and a task is free to call Submit.
          lock.unlock();
          fn();
        }
      });
    }
  }

  // The workers capture `this`, the pool must stay at a fixed address.
  ThreadPool(ThreadPool const&) = delete;
  ThreadPool& operator=(ThreadPool const&) = delete;

  /**
   * @brief Signal the workers to stop, then wait until all queued tasks are finished.
   */
  ~ThreadPool() {
    {
      std::lock_guard guard{mu_};
      stop_ = true;
    }
    // stop_ is written under the lock, so no waiter can miss this notification.
    cv_.notify_all();

    for (auto& t : pool_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

  /**
   * @brief Submit a function that doesn't take any argument.
   *
   * @param fn Callable to run on a worker thread. It's copied when passed as an lvalue
   *           and moved when passed as an rvalue. It must be copy constructible.
   *
   * @return A future holding the result of `fn`. An exception thrown by `fn` is stored in
   *         the future and rethrown by `get`.
   */
  template <typename Fn, typename R = std::invoke_result_t<Fn>>
  auto Submit(Fn&& fn) {
    // Use shared ptr to make the task copy constructible.
    auto p{std::make_shared<std::promise<R>>()};
    auto fut = p->get_future();
    auto ffn = std::function<void()>{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
      try {
        if constexpr (std::is_void_v<R>) {
          fn();
          task->set_value();
        } else {
          task->set_value(fn());
        }
      } catch (...) {
        // Hand the error over to the waiting side, an exception leaving the worker
        // thread would terminate the process.
        task->set_exception(std::current_exception());
      }
    }};

    {
      std::lock_guard guard{mu_};
      tasks_.push(std::move(ffn));
    }

    cv_.notify_one();
    return fut;
  }
};
96+
} // namespace xgboost::common

src/data/sparse_page_source.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#endif // !defined(XGBOOST_USE_CUDA)
2121

2222
#include "../common/io.h" // for PrivateMmapConstStream
23+
#include "../common/threadpool.h" // for ThreadPool
2324
#include "../common/timer.h" // for Monitor, Timer
2425
#include "proxy_dmatrix.h" // for DMatrixProxy
2526
#include "sparse_page_writer.h" // for SparsePageFormat
@@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
148149
std::mutex single_threaded_;
149150
// The current page.
150151
std::shared_ptr<S> page_;
152+
// Workers for fetching data from external memory.
153+
common::ThreadPool workers_;
151154

152155
bool at_end_ {false};
153156
float missing_;
@@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
161164
std::shared_ptr<Cache> cache_info_;
162165

163166
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
164-
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we
165-
// can pre-fetch data in a ring.
167+
// A ring storing futures to data. Since the DMatrix iterator is forward only, we can
168+
// pre-fetch data in a ring.
166169
std::unique_ptr<Ring> ring_{new Ring};
167170
// Catching exceptions in pre-fetch threads to prevent segfaults. This doesn't always work though,
168171
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
@@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
180183
}
181184
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
182185
// to let user adjust number of pre-fetched batches when needed.
183-
std::int32_t n_prefetches = std::max(nthreads_, 3);
186+
std::int32_t kPrefetches = 3;
187+
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
188+
n_prefetches = std::max(n_prefetches, 1);
184189
std::int32_t n_prefetch_batches =
185190
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
186191
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
192+
CHECK_LE(n_prefetch_batches, kPrefetches);
187193
std::size_t fetch_it = count_;
188194

189195
exce_.Rethrow();
@@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
196202
}
197203
auto const* self = this; // make sure it's const
198204
CHECK_LT(fetch_it, cache_info_->offset.size());
199-
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
205+
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
200206
*GlobalConfigThreadLocalStore::Get() = config;
201207
auto page = std::make_shared<S>();
202208
this->exce_.Run([&] {
@@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
252258
public:
253259
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
254260
std::shared_ptr<Cache> cache)
255-
: missing_{missing},
261+
: workers_{nthreads},
262+
missing_{missing},
256263
nthreads_{nthreads},
257264
n_features_{n_features},
258265
n_batches_{n_batches},

tests/cpp/common/test_threadpool.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/**
2+
* Copyright 2024, XGBoost Contributors
3+
*/
4+
#include <gtest/gtest.h>
5+
6+
#include <cstddef> // for size_t
7+
#include <cstdint> // for int32_t
8+
#include <future> // for future
9+
#include <thread> // for sleep_for, thread
10+
11+
#include "../../../src/common/threadpool.h"
12+
13+
namespace xgboost::common {
14+
// Smoke test for ThreadPool: single tasks with different return types, then batches of
// tasks (with and without a delay) whose results must come back through the matching
// future.
TEST(ThreadPool, Basic) {
  auto n_threads = static_cast<std::int32_t>(std::thread::hardware_concurrency());
  // hardware_concurrency may return 0 when the value is not computable. A pool without
  // any worker never completes a task and the test would hang on the first `get`.
  if (n_threads < 1) {
    n_threads = 1;
  }
  ThreadPool pool{n_threads};
  {
    auto fut = pool.Submit([] { return 3; });
    ASSERT_EQ(fut.get(), 3);
  }
  {
    auto fut = pool.Submit([] { return std::string{"ok"}; });
    ASSERT_EQ(fut.get(), "ok");
  }
  {
    // More tasks than workers, each task takes a while so that the queue fills up.
    std::vector<std::future<std::size_t>> futures;
    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
      futures.emplace_back(pool.Submit([=] {
        std::this_thread::sleep_for(std::chrono::milliseconds{10});
        return i;
      }));
    }
    for (std::size_t i = 0; i < futures.size(); ++i) {
      ASSERT_EQ(futures[i].get(), i);
    }
  }
  {
    // Same as above with tasks that return immediately.
    std::vector<std::future<std::size_t>> futures;
    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
      futures.emplace_back(pool.Submit([=] {
        return i;
      }));
    }
    for (std::size_t i = 0; i < futures.size(); ++i) {
      ASSERT_EQ(futures[i].get(), i);
    }
  }
}
49+
} // namespace xgboost::common

0 commit comments

Comments
 (0)