Skip to content

Commit 784a19e

Browse files
committed
fix some thread-safty issue and simplify threadpool
test=develop
1 parent 2256fae commit 784a19e

File tree

3 files changed

+18
-39
lines changed

3 files changed

+18
-39
lines changed

paddle/fluid/framework/threadpool.cc

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ ThreadPool* ThreadPool::GetInstance() {
3434
return threadpool_.get();
3535
}
3636

37+
void ThreadPool::Reset() {
38+
threadpool_.reset(nullptr);
39+
ThreadPool::Init();
40+
}
41+
3742
void ThreadPool::Init() {
3843
if (threadpool_.get() == nullptr) {
3944
// TODO(Yancey1989): specify the max threads number
@@ -59,6 +64,7 @@ ThreadPool::ThreadPool(int num_threads)
5964
ThreadPool::~ThreadPool() {
6065
{
6166
// notify all threads to stop running
67+
std::lock_guard<std::mutex> l(mutex_);
6268
running_ = false;
6369
scheduled_.notify_all();
6470
}
@@ -69,19 +75,18 @@ ThreadPool::~ThreadPool() {
6975
}
7076
}
7177

72-
void ThreadPool::Wait() {
73-
std::unique_lock<std::mutex> lock(mutex_);
74-
completed_.wait(lock, [=] { return Done() == true; });
75-
}
76-
7778
void ThreadPool::TaskLoop() {
78-
while (running_) {
79+
while (true) {
7980
std::unique_lock<std::mutex> lock(mutex_);
80-
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
8181

82-
if (!running_) {
83-
break;
82+
scheduled_.wait(
83+
lock, [this] { return !this->tasks_.empty() || !this->running_; });
84+
85+
std::lock_guard<std::mutex> l(mutex_);
86+
if (!running_ || tasks_.empty()) {
87+
return;
8488
}
89+
8590
// pop a task from the task queue
8691
auto task = std::move(tasks_.front());
8792
tasks_.pop();
@@ -91,14 +96,6 @@ void ThreadPool::TaskLoop() {
9196

9297
// run the task
9398
task();
94-
95-
{
96-
std::unique_lock<std::mutex> lock(mutex_);
97-
++idle_threads_;
98-
if (Done()) {
99-
completed_.notify_all();
100-
}
101-
}
10299
}
103100
}
104101

paddle/fluid/framework/threadpool.h

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,10 @@ class ThreadPool {
5555
// Returns the singleton of ThreadPool.
5656
static ThreadPool* GetInstance();
5757

58-
~ThreadPool();
59-
60-
// Returns the number of threads created by the constructor.
61-
size_t Threads() const { return total_threads_; }
58+
// delete current thread pool and create a new one.
59+
static void Reset();
6260

63-
// Returns the number of currently idle threads.
64-
size_t IdleThreads() {
65-
std::unique_lock<std::mutex> lock(mutex_);
66-
return idle_threads_;
67-
}
61+
~ThreadPool();
6862

6963
// Run pushes a function to the task queue and returns a std::future
7064
// object. To wait for the completion of the task, call
@@ -94,25 +88,13 @@ class ThreadPool {
9488
});
9589
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
9690
tasks_.push(std::move(task));
97-
lock.unlock();
9891
scheduled_.notify_one();
9992
return f;
10093
}
10194

102-
// Wait until all the tasks are completed.
103-
void Wait();
104-
10595
private:
10696
DISABLE_COPY_AND_ASSIGN(ThreadPool);
10797

108-
// If the task queue is empty and avaialbe is equal to the number of
109-
// threads, means that all tasks are completed. Note: this function
110-
// is not thread-safe. Returns true if all tasks are completed.
111-
// Note: don't delete the data member total_threads_ and use
112-
// threads_.size() instead; because you'd need to lock the mutex
113-
// before accessing threads_.
114-
bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
115-
11698
// The constructor starts threads to run TaskLoop, which retrieves
11799
// and runs tasks from the queue.
118100
void TaskLoop();

paddle/fluid/framework/threadpool_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ TEST(ThreadPool, ConcurrentRun) {
5252
for (auto& t : threads) {
5353
t.join();
5454
}
55-
pool->Wait();
55+
framework::ThreadPool::Reset();
5656
EXPECT_EQ(sum, ((n + 1) * n) / 2);
5757
}

0 commit comments

Comments
 (0)