Skip to content

Commit 1bdc726

Browse files
authored
Merge pull request #9578 from typhoonzero/threadpool_for_io
Multi stream thread pool
2 parents 2c552d4 + a08bf76 commit 1bdc726

File tree

5 files changed

+63
-27
lines changed

5 files changed

+63
-27
lines changed

paddle/fluid/framework/threadpool.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
#include "paddle/fluid/framework/threadpool.h"
1616

17+
#include "gflags/gflags.h"
1718
#include "paddle/fluid/platform/enforce.h"
1819

20+
DEFINE_int32(io_threadpool_size, 100,
21+
"number of threads used for doing IO, default 100");
22+
1923
namespace paddle {
2024
namespace framework {
2125

@@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() {
9195
}
9296
}
9397

98+
std::unique_ptr<ThreadPool> ThreadPoolIO::io_threadpool_(nullptr);
99+
std::once_flag ThreadPoolIO::io_init_flag_;
100+
101+
ThreadPool* ThreadPoolIO::GetInstanceIO() {
102+
std::call_once(io_init_flag_, &ThreadPoolIO::InitIO);
103+
return io_threadpool_.get();
104+
}
105+
106+
void ThreadPoolIO::InitIO() {
107+
if (io_threadpool_.get() == nullptr) {
108+
// TODO(typhoonzero1986): make this configurable
109+
io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
110+
}
111+
}
112+
94113
} // namespace framework
95114
} // namespace paddle

paddle/fluid/framework/threadpool.h

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include <condition_variable>
17+
#include <condition_variable> // NOLINT
1818
#include <functional>
19-
#include <future>
20-
#include <mutex>
19+
#include <future> // NOLINT
20+
#include <mutex> // NOLINT
2121
#include <queue>
22-
#include <thread>
22+
#include <thread> // NOLINT
2323
#include <vector>
2424
#include "glog/logging.h"
2525
#include "paddle/fluid/platform/enforce.h"
@@ -28,6 +28,22 @@ limitations under the License. */
2828
namespace paddle {
2929
namespace framework {
3030

31+
struct ExceptionHandler {
32+
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
33+
explicit ExceptionHandler(
34+
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
35+
: future_(std::move(f)) {}
36+
void operator()() const {
37+
auto ex = this->future_.get();
38+
if (ex != nullptr) {
39+
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
40+
"should use RunAndGetException to handle the exception.\n"
41+
"The default exception handler is LOG(FATAL)."
42+
<< ex->what();
43+
}
44+
}
45+
};
46+
3147
// ThreadPool maintains a queue of tasks, and runs them using a fixed
3248
// number of threads.
3349
class ThreadPool {
@@ -87,22 +103,6 @@ class ThreadPool {
87103
void Wait();
88104

89105
private:
90-
struct ExceptionHandler {
91-
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
92-
explicit ExceptionHandler(
93-
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
94-
: future_(std::move(f)) {}
95-
void operator()() const {
96-
auto ex = this->future_.get();
97-
if (ex != nullptr) {
98-
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
99-
"should use RunAndGetException to handle the exception.\n"
100-
"The default exception handler is LOG(FATAL)."
101-
<< ex->what();
102-
}
103-
}
104-
};
105-
106106
DISABLE_COPY_AND_ASSIGN(ThreadPool);
107107

108108
// If the task queue is empty and avaialbe is equal to the number of
@@ -135,6 +135,17 @@ class ThreadPool {
135135
std::condition_variable completed_;
136136
};
137137

138+
class ThreadPoolIO : ThreadPool {
139+
public:
140+
static ThreadPool* GetInstanceIO();
141+
static void InitIO();
142+
143+
private:
144+
// NOTE: threadpool in base will be inhereted here.
145+
static std::unique_ptr<ThreadPool> io_threadpool_;
146+
static std::once_flag io_init_flag_;
147+
};
148+
138149
// Run a function asynchronously.
139150
// NOTE: The function must return void. If the function need to return a value,
140151
// you can use lambda to capture a value pointer.
@@ -143,5 +154,10 @@ std::future<void> Async(Callback callback) {
143154
return ThreadPool::GetInstance()->Run(callback);
144155
}
145156

157+
template <typename Callback>
158+
std::future<void> AsyncIO(Callback callback) {
159+
return ThreadPoolIO::GetInstanceIO()->Run(callback);
160+
}
161+
146162
} // namespace framework
147163
} // namespace paddle

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
3535
const framework::Scope* p_scope = &scope;
3636
const auto ch = GetChannel(ep_val);
3737

38-
framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
38+
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
39+
this] {
3940
auto* var = p_scope->FindVar(var_name_val);
4041

4142
::grpc::ByteBuffer req;
@@ -89,7 +90,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
8990
const framework::Scope* p_scope = &scope;
9091
const auto ch = GetChannel(ep_val);
9192

92-
framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
93+
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
94+
this] {
9395
// prepare input
9496
sendrecv::VariableMessage req;
9597
req.set_varname(var_name_val);
@@ -132,8 +134,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
132134
const framework::Scope* p_scope = &scope;
133135
const auto ch = GetChannel(ep_val);
134136

135-
framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
136-
time_out, ch, this] {
137+
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
138+
time_out, ch, this] {
137139
auto* var = p_scope->FindVar(in_var_name_val);
138140

139141
::grpc::ByteBuffer req;
@@ -196,7 +198,7 @@ bool RPCClient::Wait() {
196198
std::vector<std::future<void>> waits(req_count_);
197199

198200
for (int i = 0; i < req_count_; i++) {
199-
waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
201+
waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); });
200202
}
201203

202204
for (int i = 0; i < req_count_; i++) {

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
217217
std::function<void()> prefetch_register =
218218
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
219219

220+
// TODO(wuyi): Run these "HandleRequest" in thread pool
220221
t_send_.reset(
221222
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
222223
cq_send_.get(), "cq_send", send_register)));
223-
224224
t_get_.reset(
225225
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
226226
cq_get_.get(), "cq_get", get_register)));

python/paddle/fluid/tests/book/test_recognize_digits.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def train_loop(main_program):
157157
for ip in pserver_ips.split(","):
158158
eplist.append(':'.join([ip, port]))
159159
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
160-
pserver_endpoints = os.getenv("PSERVERS")
161160
trainers = int(os.getenv("TRAINERS"))
162161
current_endpoint = os.getenv("POD_IP") + ":" + port
163162
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))

0 commit comments

Comments
 (0)