Skip to content

Commit f43be75

Browse files
committed
multi stream thread pool
1 parent 3fd9266 commit f43be75

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

paddle/fluid/framework/threadpool.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,20 @@ void ThreadPool::TaskLoop() {
9191
}
9292
}
9393

94+
std::unique_ptr<ThreadPool> MultiStreamThreadPool::io_threadpool_(nullptr);
95+
std::once_flag MultiStreamThreadPool::io_init_flag_;
96+
97+
MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() {
98+
std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO);
99+
return static_cast<MultiStreamThreadPool*>(io_threadpool_.get());
100+
}
101+
102+
void MultiStreamThreadPool::InitIO() {
103+
if (io_threadpool_.get() == nullptr) {
104+
// TODO(typhoonzero1986): make this configurable
105+
io_threadpool_.reset(new ThreadPool(100));
106+
}
107+
}
108+
94109
} // namespace framework
95110
} // namespace paddle

paddle/fluid/framework/threadpool.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ class ThreadPool {
135135
std::condition_variable completed_;
136136
};
137137

138+
class MultiStreamThreadPool : ThreadPool {
139+
public:
140+
static MultiStreamThreadPool* GetInstanceIO();
141+
static void InitIO();
142+
143+
private:
144+
// NOTE: threadpool in base will be inhereted here.
145+
static std::unique_ptr<ThreadPool> io_threadpool_;
146+
static std::once_flag io_init_flag_;
147+
};
148+
138149
// Run a function asynchronously.
139150
// NOTE: The function must return void. If the function need to return a value,
140151
// you can use lambda to capture a value pointer.
@@ -143,5 +154,10 @@ std::future<void> Async(Callback callback) {
143154
return ThreadPool::GetInstance()->Run(callback);
144155
}
145156

157+
template <typename Callback>
158+
std::future<void> AsyncIO(Callback callback) {
159+
return MultiStreamThreadPool::GetInstanceIO()->Run(callback);
160+
}
161+
146162
} // namespace framework
147163
} // namespace paddle

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
3333
const framework::Scope* p_scope = &scope;
3434
const auto ch = GetChannel(ep_val);
3535

36-
framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
36+
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
37+
this] {
3738
auto* var = p_scope->FindVar(var_name_val);
3839

3940
::grpc::ByteBuffer req;
@@ -88,7 +89,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
8889
const framework::Scope* p_scope = &scope;
8990
const auto ch = GetChannel(ep_val);
9091

91-
framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
92+
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
93+
this] {
9294
// prepare input
9395
sendrecv::VariableMessage req;
9496
req.set_varname(var_name_val);
@@ -131,8 +133,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
131133
const framework::Scope* p_scope = &scope;
132134
const auto ch = GetChannel(ep_val);
133135

134-
framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
135-
time_out, ch, this] {
136+
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
137+
time_out, ch, this] {
136138
auto* var = p_scope->FindVar(in_var_name_val);
137139

138140
::grpc::ByteBuffer req;
@@ -195,7 +197,7 @@ bool RPCClient::Wait() {
195197
std::vector<std::future<void>> waits(req_count_);
196198

197199
for (int i = 0; i < req_count_; i++) {
198-
waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
200+
waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); });
199201
}
200202

201203
for (int i = 0; i < req_count_; i++) {

0 commit comments

Comments
 (0)