Skip to content

Commit 6449fae

Browse files
authored
Merge pull request #14259 from jacquesqiao/optimize-thread-pool
Optimize thread pool
2 parents a9b5d42 + 4062f00 commit 6449fae

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

paddle/fluid/framework/threadpool.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) {
5757
ThreadPool::~ThreadPool() {
5858
{
5959
// notify all threads to stop running
60-
std::lock_guard<std::mutex> l(mutex_);
60+
std::unique_lock<std::mutex> l(mutex_);
6161
running_ = false;
62-
scheduled_.notify_all();
6362
}
63+
scheduled_.notify_all();
6464

6565
for (auto& t : threads_) {
6666
t->join();
@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() {
7070

7171
void ThreadPool::TaskLoop() {
7272
while (true) {
73-
std::unique_lock<std::mutex> lock(mutex_);
73+
Task task;
7474

75-
scheduled_.wait(
76-
lock, [this] { return !this->tasks_.empty() || !this->running_; });
75+
{
76+
std::unique_lock<std::mutex> lock(mutex_);
77+
scheduled_.wait(
78+
lock, [this] { return !this->tasks_.empty() || !this->running_; });
7779

78-
if (!running_ || tasks_.empty()) {
79-
return;
80-
}
80+
if (!running_ && tasks_.empty()) {
81+
return;
82+
}
83+
84+
if (tasks_.empty()) {
85+
PADDLE_THROW("This thread has no task to Run");
86+
}
8187

82-
// pop a task from the task queue
83-
auto task = std::move(tasks_.front());
84-
tasks_.pop();
85-
lock.unlock();
88+
// pop a task from the task queue
89+
task = std::move(tasks_.front());
90+
tasks_.pop();
91+
}
8692

8793
// run the task
8894
task();

paddle/fluid/framework/threadpool.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class ThreadPool {
5858
~ThreadPool();
5959

6060
// Run pushes a function to the task queue and returns a std::future
61-
// object. To wait for the completion of the task, call
61+
// object. To wait for the completion of the task, call
6262
// std::future::wait().
6363
template <typename Callback>
6464
std::future<void> Run(Callback fn) {
@@ -69,7 +69,6 @@ class ThreadPool {
6969
template <typename Callback>
7070
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
7171
Callback fn) {
72-
std::unique_lock<std::mutex> lock(mutex_);
7372
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
7473
try {
7574
fn();
@@ -84,7 +83,13 @@ class ThreadPool {
8483
return nullptr;
8584
});
8685
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
87-
tasks_.push(std::move(task));
86+
{
87+
std::unique_lock<std::mutex> lock(mutex_);
88+
if (!running_) {
89+
PADDLE_THROW("enqueue on stopped ThreadPool");
90+
}
91+
tasks_.push(std::move(task));
92+
}
8893
scheduled_.notify_one();
8994
return f;
9095
}

0 commit comments

Comments
 (0)