diff --git a/src/libutil-tests/meson.build b/src/libutil-tests/meson.build index 2772ab0604d..8133c1576d1 100644 --- a/src/libutil-tests/meson.build +++ b/src/libutil-tests/meson.build @@ -74,6 +74,7 @@ sources = files( 'strings.cc', 'suggestions.cc', 'terminal.cc', + 'thread-pool.cc', 'topo-sort.cc', 'url.cc', 'util.cc', diff --git a/src/libutil-tests/thread-pool.cc b/src/libutil-tests/thread-pool.cc new file mode 100644 index 00000000000..5a22e0e79ef --- /dev/null +++ b/src/libutil-tests/thread-pool.cc @@ -0,0 +1,33 @@ +#include "nix/util/thread-pool.hh" +#include + +namespace nix { + +using namespace std; + +TEST(threadpool, correctValue) +{ + ThreadPool pool(3); + int sum = 0; + std::mutex mtx; + for (int i = 0; i < 20; i++) { + pool.enqueue([&] { + std::lock_guard lock(mtx); + sum += 1; + }); + } + pool.process(); + ASSERT_EQ(sum, 20); +} + +TEST(threadpool, properlyHandlesDirectExceptions) +{ + struct TestExn + {}; + + ThreadPool pool(3); + pool.enqueue([&] { throw TestExn(); }); + EXPECT_THROW(pool.process(), TestExn); +} + +} // namespace nix \ No newline at end of file diff --git a/src/libutil/include/nix/util/closure.hh b/src/libutil/include/nix/util/closure.hh index 9e37b4cfb02..04872bbe5c6 100644 --- a/src/libutil/include/nix/util/closure.hh +++ b/src/libutil/include/nix/util/closure.hh @@ -4,6 +4,7 @@ #include #include #include "nix/util/sync.hh" +#include "nix/util/thread-pool.hh" using std::set; @@ -17,57 +18,32 @@ void computeClosure(const set startElts, set & res, GetEdgesAsync getEd { struct State { - size_t pending; set & res; - std::exception_ptr exc; }; - Sync state_(State{0, res, 0}); + Sync state_(State{res}); - std::condition_variable done; + ThreadPool pool(0); auto enqueue = [&](this auto & enqueue, const T & current) -> void { { auto state(state_.lock()); - if (state->exc) - return; if (!state->res.insert(current).second) return; - state->pending++; } - - getEdgesAsync(current, [&](std::promise> & prom) { - try { + pool.enqueue([&, current] { + getEdgesAsync(current, [&](std::promise> & prom) { auto children = prom.get_future().get(); for (auto & child : children) enqueue(child); - { - auto state(state_.lock()); - assert(state->pending); - if (!--state->pending) - done.notify_one(); - } - } catch (...) { - auto state(state_.lock()); - if (!state->exc) - state->exc = std::current_exception(); - assert(state->pending); - if (!--state->pending) - done.notify_one(); - }; + }); }); }; for (auto & startElt : startElts) enqueue(startElt); - { - auto state(state_.lock()); - while (state->pending) - state.wait(done); - if (state->exc) - std::rethrow_exception(state->exc); - } + pool.process(); } } // namespace nix