Skip to content

Commit 26cfc63

Browse files
committed
multi stream thread pool
1 parent 7050039 commit 26cfc63

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

paddle/fluid/framework/threadpool.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
#include "paddle/fluid/framework/threadpool.h"
1616

17+
#include "gflags/gflags.h"
1718
#include "paddle/fluid/platform/enforce.h"
1819

20+
DEFINE_int32(io_threadpool_size, 100,
21+
"number of threads used for doing IO, default 100");
22+
1923
namespace paddle {
2024
namespace framework {
2125

@@ -94,15 +98,15 @@ void ThreadPool::TaskLoop() {
9498
std::unique_ptr<ThreadPool> MultiStreamThreadPool::io_threadpool_(nullptr);
9599
std::once_flag MultiStreamThreadPool::io_init_flag_;
96100

97-
MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() {
101+
ThreadPool* MultiStreamThreadPool::GetInstanceIO() {
98102
std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO);
99-
return static_cast<MultiStreamThreadPool*>(io_threadpool_.get());
103+
return io_threadpool_.get();
100104
}
101105

102106
void MultiStreamThreadPool::InitIO() {
103107
if (io_threadpool_.get() == nullptr) {
104108
// TODO(typhoonzero1986): make this configurable
105-
io_threadpool_.reset(new ThreadPool(100));
109+
io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
106110
}
107111
}
108112

paddle/fluid/framework/threadpool.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include <condition_variable>
17+
#include <condition_variable> // NOLINT
1818
#include <functional>
19-
#include <future>
20-
#include <mutex>
19+
#include <future> // NOLINT
20+
#include <mutex> // NOLINT
2121
#include <queue>
22-
#include <thread>
22+
#include <thread> // NOLINT
2323
#include <vector>
2424
#include "glog/logging.h"
2525
#include "paddle/fluid/platform/enforce.h"
@@ -137,7 +137,7 @@ class ThreadPool {
137137

138138
class MultiStreamThreadPool : ThreadPool {
139139
public:
140-
static MultiStreamThreadPool* GetInstanceIO();
140+
static ThreadPool* GetInstanceIO();
141141
static void InitIO();
142142

143143
private:

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
216216
std::function<void()> prefetch_register =
217217
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
218218

219+
// TODO(wuyi): Run these "HandleRequest" in thread pool
219220
t_send_.reset(
220221
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
221222
cq_send_.get(), "cq_send", send_register)));
222-
223223
t_get_.reset(
224224
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
225225
cq_get_.get(), "cq_get", get_register)));

python/paddle/fluid/tests/book/test_recognize_digits.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def train_loop(main_program):
157157
for ip in pserver_ips.split(","):
158158
eplist.append(':'.join([ip, port]))
159159
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
160-
pserver_endpoints = os.getenv("PSERVERS")
161160
trainers = int(os.getenv("TRAINERS"))
162161
current_endpoint = os.getenv("POD_IP") + ":" + port
163162
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))

0 commit comments

Comments
 (0)