Skip to content

Commit 9137806

Browse files
committed
feat(worker_group): allow worker threads to carry user-defined state
1 parent ee34a37 commit 9137806

File tree

2 files changed

+103
-68
lines changed

2 files changed

+103
-68
lines changed

include/dwarfs/internal/worker_group.h

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#pragma once
3030

31+
#include <any>
3132
#include <chrono>
3233
#include <concepts>
3334
#include <cstddef>
@@ -37,6 +38,7 @@
3738
#include <memory>
3839
#include <optional>
3940
#include <utility>
41+
#include <variant>
4042

4143
#include <folly/Function.h>
4244

@@ -47,6 +49,51 @@ class os_access;
4749

4850
namespace internal {
4951

52+
class thread_state {
53+
public:
54+
virtual ~thread_state() = default;
55+
56+
virtual void apply(std::any&& job_any) = 0;
57+
};
58+
59+
template <typename... Args>
60+
class basic_thread_state : public thread_state {
61+
public:
62+
using job_t = std::function<void(Args...)>;
63+
using moveonly_job_t = folly::Function<void(Args...)>;
64+
using any_job_t = std::variant<job_t, moveonly_job_t>;
65+
66+
explicit basic_thread_state(Args... args)
67+
: args_(std::make_tuple(std::forward<Args>(args)...)) {}
68+
69+
void apply(std::any&& job_any) override {
70+
std::visit(
71+
[this](auto&& j) {
72+
static_assert(std::is_rvalue_reference_v<decltype(j)>);
73+
auto job = std::forward<decltype(j)>(j);
74+
std::apply(job, args_);
75+
},
76+
std::move(
77+
*std::any_cast<std::shared_ptr<any_job_t>>(std::move(job_any))));
78+
}
79+
80+
static std::any make_job(job_t&& job) {
81+
return std::any(std::make_shared<any_job_t>(std::move(job)));
82+
}
83+
84+
static std::any make_job(moveonly_job_t&& job) {
85+
return std::any(std::make_shared<any_job_t>(std::move(job)));
86+
}
87+
88+
template <std::invocable<Args...> T>
89+
static std::any make_job(T&& job) {
90+
return make_job(moveonly_job_t(std::forward<T>(job)));
91+
}
92+
93+
private:
94+
std::tuple<Args...> args_;
95+
};
96+
5097
/**
5198
* A group of worker threads
5299
*
@@ -56,17 +103,20 @@ namespace internal {
56103
*/
57104
class worker_group {
58105
public:
59-
using job_t = std::function<void()>;
60-
using moveonly_job_t = folly::Function<void()>;
61-
62106
/**
63107
* Create a worker group
64108
*
65109
* \param num_workers Number of worker threads.
66110
*/
67-
explicit worker_group(
111+
worker_group(logger& lgr, os_access const& os, char const* group_name,
112+
size_t num_workers = 1,
113+
size_t max_queue_len = std::numeric_limits<size_t>::max(),
114+
int niceness = 0);
115+
116+
worker_group(
68117
logger& lgr, os_access const& os, char const* group_name,
69-
size_t num_workers = 1,
118+
size_t num_workers,
119+
std::function<std::unique_ptr<thread_state>(size_t)> thread_state_factory,
70120
size_t max_queue_len = std::numeric_limits<size_t>::max(),
71121
int niceness = 0);
72122

@@ -82,14 +132,10 @@ class worker_group {
82132
void wait() { impl_->wait(); }
83133
bool running() const { return impl_->running(); }
84134

85-
bool add_job(job_t&& job) { return impl_->add_job(std::move(job)); }
86-
bool add_job(moveonly_job_t&& job) {
87-
return impl_->add_moveonly_job(std::move(job));
88-
}
89-
90-
template <std::invocable T>
91-
bool add_job(T&& job) {
92-
return add_job(moveonly_job_t{std::forward<T>(job)});
135+
template <typename... Args>
136+
bool add_job(std::invocable<Args...> auto&& job) {
137+
return impl_->add_job(basic_thread_state<Args...>::make_job(
138+
std::forward<decltype(job)>(job)));
93139
}
94140

95141
size_t size() const { return impl_->size(); }
@@ -119,8 +165,7 @@ class worker_group {
119165
virtual void stop() = 0;
120166
virtual void wait() = 0;
121167
virtual bool running() const = 0;
122-
virtual bool add_job(job_t&& job) = 0;
123-
virtual bool add_moveonly_job(moveonly_job_t&& job) = 0;
168+
virtual bool add_job(std::any&& job) = 0;
124169
virtual size_t size() const = 0;
125170
virtual size_t queue_size() const = 0;
126171
virtual std::chrono::nanoseconds

src/internal/worker_group.cpp

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
#include <string>
3939
#include <thread>
4040
#include <type_traits>
41-
#include <variant>
4241
#include <vector>
4342

4443
#include <fmt/format.h>
@@ -62,9 +61,11 @@ template <typename LoggerPolicy, typename Policy>
6261
class basic_worker_group final : public worker_group::impl, private Policy {
6362
public:
6463
template <typename... Args>
65-
basic_worker_group(logger& lgr, os_access const& os, char const* group_name,
66-
size_t num_workers, size_t max_queue_len,
67-
int niceness [[maybe_unused]], Args&&... args)
64+
basic_worker_group(
65+
logger& lgr, os_access const& os, char const* group_name,
66+
size_t num_workers,
67+
std::function<std::unique_ptr<thread_state>(size_t)> thread_state_factory,
68+
size_t max_queue_len, int niceness [[maybe_unused]], Args&&... args)
6869
: Policy(std::forward<Args>(args)...)
6970
, LOG_PROXY_INIT(lgr)
7071
, os_{os}
@@ -80,11 +81,12 @@ class basic_worker_group final : public worker_group::impl, private Policy {
8081
}
8182

8283
for (size_t i = 0; i < num_workers; ++i) {
83-
workers_.emplace_back([this, niceness, group_name, i] {
84-
folly::setThreadName(fmt::format("{}{}", group_name, i + 1));
85-
set_thread_niceness(niceness);
86-
do_work(niceness > 10);
87-
});
84+
workers_.emplace_back(
85+
[this, niceness, group_name, i, state = thread_state_factory(i)] {
86+
folly::setThreadName(fmt::format("{}{}", group_name, i + 1));
87+
set_thread_niceness(niceness);
88+
do_work(*state, niceness > 10);
89+
});
8890
}
8991

9092
check_set_affinity_from_enviroment(group_name);
@@ -146,19 +148,21 @@ class basic_worker_group final : public worker_group::impl, private Policy {
146148
*
147149
* \param job The job to add to the dispatcher.
148150
*/
149-
bool add_job(worker_group::job_t&& job) override {
150-
return add_job_impl(std::move(job));
151-
}
151+
bool add_job(std::any&& job) override {
152+
if (running_) {
153+
{
154+
std::unique_lock lock(mx_);
155+
queue_.wait(lock, [this] { return jobs_.size() < max_queue_len_; });
156+
jobs_.emplace(std::move(job));
157+
++pending_;
158+
}
152159

153-
/**
154-
* Add a new move-only job to the worker group
155-
*
156-
* The new job will be dispatched to the first available worker thread.
157-
*
158-
* \param job The job to add to the dispatcher.
159-
*/
160-
bool add_moveonly_job(worker_group::moveonly_job_t&& job) override {
161-
return add_job_impl(std::move(job));
160+
cond_.notify_one();
161+
162+
return true;
163+
}
164+
165+
return false;
162166
}
163167

164168
/**
@@ -219,26 +223,7 @@ class basic_worker_group final : public worker_group::impl, private Policy {
219223
}
220224

221225
private:
222-
using any_job_t =
223-
std::variant<worker_group::job_t, worker_group::moveonly_job_t>;
224-
using jobs_t = std::queue<any_job_t>;
225-
226-
bool add_job_impl(any_job_t&& job) {
227-
if (running_) {
228-
{
229-
std::unique_lock lock(mx_);
230-
queue_.wait(lock, [this] { return jobs_.size() < max_queue_len_; });
231-
jobs_.emplace(std::move(job));
232-
++pending_;
233-
}
234-
235-
cond_.notify_one();
236-
237-
return true;
238-
}
239-
240-
return false;
241-
}
226+
using jobs_t = std::queue<std::any>;
242227

243228
void check_set_affinity_from_enviroment(char const* group_name) {
244229
if (auto var = os_.getenv("DWARFS_WORKER_GROUP_AFFINITY")) {
@@ -276,12 +261,12 @@ class basic_worker_group final : public worker_group::impl, private Policy {
276261
}
277262
}
278263

279-
void do_work(bool is_background [[maybe_unused]]) {
264+
void do_work(thread_state& state, bool is_background [[maybe_unused]]) {
280265
#ifdef _WIN32
281266
auto hthr = ::GetCurrentThread();
282267
#endif
283268
for (;;) {
284-
any_job_t job;
269+
std::any job;
285270

286271
{
287272
std::unique_lock lock(mx_);
@@ -310,13 +295,7 @@ class basic_worker_group final : public worker_group::impl, private Policy {
310295
}
311296
#endif
312297
try {
313-
std::visit(
314-
[](auto&& j) {
315-
static_assert(std::is_rvalue_reference_v<decltype(j)>);
316-
auto job = std::forward<decltype(j)>(j);
317-
job();
318-
},
319-
std::move(job));
298+
state.apply(std::move(job));
320299
} catch (...) {
321300
LOG_FATAL << "exception thrown in worker thread: "
322301
<< exception_str(std::current_exception());
@@ -364,11 +343,22 @@ using default_worker_group = basic_worker_group<LoggerPolicy, no_policy>;
364343

365344
} // namespace
366345

346+
worker_group::worker_group(
347+
logger& lgr, os_access const& os, char const* group_name,
348+
size_t num_workers,
349+
std::function<std::unique_ptr<thread_state>(size_t)> thread_state_factory,
350+
size_t max_queue_len, int niceness)
351+
: impl_{make_unique_logging_object<impl, default_worker_group,
352+
logger_policies>(
353+
lgr, os, group_name, num_workers, thread_state_factory, max_queue_len,
354+
niceness)} {}
355+
367356
worker_group::worker_group(logger& lgr, os_access const& os,
368357
char const* group_name, size_t num_workers,
369358
size_t max_queue_len, int niceness)
370-
: impl_{make_unique_logging_object<impl, default_worker_group,
371-
logger_policies>(
372-
lgr, os, group_name, num_workers, max_queue_len, niceness)} {}
359+
: worker_group(
360+
lgr, os, group_name, num_workers,
361+
[](size_t) { return std::make_unique<basic_thread_state<>>(); },
362+
max_queue_len, niceness) {}
373363

374364
} // namespace dwarfs::internal

0 commit comments

Comments
 (0)