Commit 49313d4

Merge pull request #9548 from panyx0718/group_nccl_all_reduce
Group nccl all reduce and improve performance (~14% for 4 device resnext)
2 parents d2f9e19 + cf251eb

11 files changed: +118 additions, -35 deletions


paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
   }
 }

-std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
+std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/details/nccl_all_reduce_op_handle.h

Lines changed: 7 additions & 0 deletions

@@ -14,6 +14,9 @@

 #pragma once

+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
@@ -34,6 +37,10 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {

   std::string Name() const override;

+  // Delaying and buffering nccl_all_reduce ops together can significantly
+  // improve performance. Return false here to disable this feature.
+  bool IsMultiDeviceTransfer() override { return true; }
+
  protected:
   void RunImpl() override;
 };

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,8 @@
 // limitations under the License.

 #pragma once
+#include <string>
+#include <vector>

 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -53,6 +55,10 @@ class OpHandleBase {

   void AddOutput(VarHandleBase *out);

+  // Returns true if the Op involves data transfer across multiple devices,
+  // which will likely block other computations.
+  virtual bool IsMultiDeviceTransfer() { return false; }
+
  protected:
   virtual void RunImpl() = 0;
 };
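The virtual method added here is the whole scheduling interface: the base class answers false, and only ops that move data across devices override it. A minimal Python analogue of the pattern (illustrative class and method names, not Paddle's actual Python API):

class OpHandleBase:
    # Default: an ordinary computation op; the executor runs it as soon
    # as its inputs are ready.
    def is_multi_device_transfer(self):
        return False

class NCCLAllReduceOpHandle(OpHandleBase):
    # All-reduce coordinates streams on every GPU, so the executor may
    # buffer it and launch it together with other all-reduces.
    def is_multi_device_transfer(self):
        return True

op = NCCLAllReduceOpHandle()
assert op.is_multi_device_transfer()  # the executor checks exactly this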

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 50 additions & 12 deletions

@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
     size_t num_threads, bool use_event,
     const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::unique_ptr<SSAGraph> &&graph)
+    std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
     : SSAGraphExecutor(std::move(graph)),
       pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
       local_scopes_(local_scopes),
       places_(places),
       fetch_ctxs_(places),
-      use_event_(use_event) {}
+      use_event_(use_event),
+      running_ops_(0),
+      allow_op_delay_(allow_op_delay) {}
+
+void ThreadedSSAGraphExecutor::RunDelayedOps(
+    const std::unordered_set<OpHandleBase *> &delayed_ops) {
+  for (auto op : delayed_ops) {
+    op->Run(use_event_);
+  }
+}

 FeedFetchList ThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
   std::unordered_set<VarHandleBase *> pending_vars;
-
   BlockingQueue<VarHandleBase *> ready_vars;
-
   std::unordered_set<OpHandleBase *> ready_ops;
+  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
+  // streams from multiple GPUs, it's faster to buffer them and schedule
+  // them together, since we currently cannot overlap computation and
+  // memcpy streams. Revisit this once overlapping becomes available.
+  std::unordered_set<OpHandleBase *> delayed_ops;
+  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
+  std::unordered_set<VarHandleBase *> delayed_vars;

   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
     pending_vars.insert(&var);
@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

   auto run_all_ready_ops = [&] {
     for (auto *op : ready_ops) {
-      RunOp(ready_vars, op);
+      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+        delayed_ops.insert(op);
+        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
+        ready_vars.Extend(op->outputs_);
+        continue;
+      }
+      running_ops_++;
+      RunOp(&ready_vars, op);
     }
     ready_ops.clear();
   };
@@ -118,13 +139,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   }

   // Step 3. Execution
-  while (!pending_vars.empty()) {
+  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
     // 1. Run All Ready ops
     run_all_ready_ops();

     // 2. Find ready variable
     bool timeout;
-    auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
+    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

     if (timeout) {
       if (exception_) {
@@ -141,13 +162,29 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
-          ready_ops.insert(op);
+          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
+            blocked_by_delayed_ops.insert(op);
+          } else {
+            ready_ops.insert(op);
+          }
         }
       }
     }
+    // When there are no other ops to schedule, schedule the buffered
+    // delayed ops and unblock the ops waiting on them.
+    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
+      RunDelayedOps(delayed_ops);
+      delayed_ops.clear();
+      for (auto *op : blocked_by_delayed_ops) {
+        ready_ops.insert(op);
+      }
+      blocked_by_delayed_ops.clear();
+    }
     // Keep looping until all vars are ready.
   }
-
+  PADDLE_ENFORCE(ready_ops.empty());
+  PADDLE_ENFORCE(delayed_ops.empty());
+  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
   ++computation_count_;

   auto sync_computation = [&] {
@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 }

 void ThreadedSSAGraphExecutor::RunOp(
-    BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
-  auto op_run = [&ready_var_q, op, this] {
+    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
+  auto op_run = [ready_var_q, op, this] {
     try {
       VLOG(10) << op->Name() << " : " << op->DebugString();
       op->Run(use_event_);
-      ready_var_q.Extend(op->outputs_);
+      running_ops_--;
+      ready_var_q->Extend(op->outputs_);
     } catch (platform::EnforceNotMet ex) {
       exception_.reset(new platform::EnforceNotMet(ex));
     } catch (...) {
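The scheduling policy added above is easier to see outside of the executor plumbing. Below is a minimal, single-threaded Python sketch of the same loop; Op, run_graph, and the print calls are illustrative stand-ins, not Paddle's real API (the real executor dispatches ops onto a thread pool, tracks running_ops_, and uses variable handles rather than strings):

from collections import deque

class Op:
    """Illustrative stand-in for OpHandleBase (not Paddle's real API)."""
    def __init__(self, name, inputs, outputs, multi_device=False):
        self.name, self.inputs, self.outputs = name, inputs, outputs
        self.multi_device = multi_device  # mirrors IsMultiDeviceTransfer()

def run_graph(ops, source_vars, allow_op_delay=True):
    pending_deps = {op: len(op.inputs) for op in ops}
    consumers = {}
    for op in ops:
        for var in op.inputs:
            consumers.setdefault(var, []).append(op)

    ready_vars = deque(source_vars)   # vars that have been produced
    ready_ops = {op for op in ops if pending_deps[op] == 0}
    delayed_ops = set()               # buffered multi-device transfer ops
    delayed_vars = set()              # their outputs: not actually ready
    blocked_by_delayed = set()        # ops waiting on a delayed var

    while ready_vars or ready_ops or delayed_ops:
        # 1. Run every ready op, except multi-device transfers, which are
        #    buffered instead of run (their outputs still flow through the
        #    queue so dependency counting proceeds).
        for op in list(ready_ops):
            if allow_op_delay and op.multi_device:
                delayed_ops.add(op)
                delayed_vars.update(op.outputs)
            else:
                print("run", op.name)
            ready_vars.extend(op.outputs)
        ready_ops.clear()

        # 2. Pop one produced var and update its consumers' dep counts.
        if ready_vars:
            var = ready_vars.popleft()
            for op in consumers.get(var, []):
                pending_deps[op] -= 1
                if pending_deps[op] == 0:
                    if var in delayed_vars:
                        blocked_by_delayed.add(op)  # not truly ready yet
                    else:
                        ready_ops.add(op)

        # 3. No other work left: flush all buffered transfers together,
        #    then release the ops they were blocking.
        if not ready_ops and not ready_vars and delayed_ops:
            for op in delayed_ops:
                print("run delayed", op.name)
            delayed_ops.clear()
            delayed_vars.clear()
            ready_ops |= blocked_by_delayed
            blocked_by_delayed.clear()

ops = [
    Op("forward", ["x"], ["loss"]),
    Op("backward", ["loss"], ["w_grad"]),
    Op("nccl_all_reduce", ["w_grad"], ["w_grad_sum"], multi_device=True),
    Op("sgd", ["w_grad_sum"], ["w"]),
]
run_graph(ops, source_vars=["x"])

Running this prints forward and backward immediately, buffers nccl_all_reduce until nothing else can make progress, flushes it, and only then runs sgd — the grouping behavior the commit title refers to.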

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 13 additions & 3 deletions

@@ -14,7 +14,12 @@

 #pragma once

-#include <chrono>
+#include <deque>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
 #include <functional>
 #include "ThreadPool.h"  // ThreadPool in third party
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -70,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::unique_ptr<SSAGraph> &&graph);
+                           std::unique_ptr<SSAGraph> &&graph,
+                           bool allow_op_delay);

   // Run a SSAGraph by a thread pool
   // Use topological sort algorithm
@@ -79,16 +85,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ~ThreadedSSAGraphExecutor() {}

  private:
-  void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
+  void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
             details::OpHandleBase *op);

+  void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);
+
  private:
   std::unique_ptr<::ThreadPool> pool_;
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
   platform::DeviceContextPool fetch_ctxs_;
   const bool use_event_;
   std::unique_ptr<platform::EnforceNotMet> exception_;
+  std::atomic<int> running_ops_;
+  bool allow_op_delay_;

   size_t computation_count_{0};
   size_t max_async_computation{100};

paddle/fluid/framework/parallel_executor.cc

Lines changed: 5 additions & 3 deletions

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/parallel_executor.h"
+#include "paddle/fluid/platform/profiler.h"

 #include <string>
 #include <vector>
@@ -47,7 +48,7 @@ ParallelExecutor::ParallelExecutor(
     const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &params,
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
-    const std::string &loss_var_name, Scope *scope)
+    const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;

@@ -82,8 +83,8 @@ ParallelExecutor::ParallelExecutor(
   auto graph = builder.Build(main_program);

   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-      num_threads, use_event, member_->local_scopes_, places,
-      std::move(graph)));
+      num_threads, use_event, member_->local_scopes_, places, std::move(graph),
+      allow_op_delay));

   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(

 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
+  platform::RecordBlock b(0);
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
       fetch_data;

paddle/fluid/framework/parallel_executor.h

Lines changed: 4 additions & 2 deletions

@@ -14,8 +14,9 @@ limitations under the License. */

 #pragma once

-#include <future>
+#include <string>
 #include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -37,7 +38,8 @@ class ParallelExecutor {
                    const std::unordered_set<std::string>& params,
                    const ProgramDesc& startup_program,
                    const ProgramDesc& main_program,
-                   const std::string& loss_var_name, Scope* scope);
+                   const std::string& loss_var_name, Scope* scope,
+                   bool allow_op_delay);

   void Run(const std::vector<std::string>& fetch_tensors,
            const std::string& fetched_var_name = "fetched_var");

paddle/fluid/pybind/pybind.cc

Lines changed: 2 additions & 2 deletions

@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
              const std::unordered_set<std::string> &params,
              const ProgramDesc &startup_program,
              const ProgramDesc &main_program, const std::string &loss_var_name,
-             Scope *scope) {
+             Scope *scope, bool allow_op_delay) {
            new (&self) ParallelExecutor(num_threads, use_event, places,
                                         params, startup_program, main_program,
-                                        loss_var_name, scope);
+                                        loss_var_name, scope, allow_op_delay);
          })
      .def("run", &ParallelExecutor::Run);

python/paddle/fluid/parallel_executor.py

Lines changed: 13 additions & 3 deletions

@@ -21,7 +21,11 @@


 class ParallelExecutor(object):
-    def __init__(self, loss_name, use_cuda, num_threads=None):
+    def __init__(self,
+                 loss_name,
+                 use_cuda,
+                 num_threads=None,
+                 allow_op_delay=False):
         places = []
         if use_cuda:
             for i in xrange(core.get_cuda_device_count()):
@@ -35,7 +39,12 @@ def __init__(self, loss_name, use_cuda, num_threads=None):
                 places.append(p)

         if num_threads is None:
-            num_threads = min(len(places) * 2, multiprocessing.cpu_count())
+            if use_cuda:
+                # Experiments on se-resnext show that too many threads hurt
+                # performance. Worth tuning for other models in the future.
+                num_threads = len(places)
+            else:
+                num_threads = min(len(places) * 2, multiprocessing.cpu_count())

         startup = framework.default_startup_program()
         main = framework.default_main_program()
@@ -52,7 +61,8 @@ def __init__(self, loss_name, use_cuda, num_threads=None):
             startup.desc,
             main.desc,
             loss_name,
-            scope)
+            scope,
+            allow_op_delay)
         self.scope = scope

     def run(self, fetch_list):
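For reference, a usage sketch of the new flag. It assumes a fluid program with a loss variable has already been built; loss and the surrounding setup are hypothetical, and ParallelExecutor is taken to be importable from paddle.fluid as in this commit's Python module:

import paddle.fluid as fluid

# ... build the network first; `loss` is a hypothetical loss variable ...
exe = fluid.ParallelExecutor(
    loss_name=loss.name,
    use_cuda=True,
    allow_op_delay=True)  # buffer nccl_all_reduce ops and launch them grouped
results = exe.run(fetch_list=[loss.name])

With allow_op_delay=False (the default), the executor behaves as before this commit and runs each all-reduce as soon as it becomes ready.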

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,7 @@ function(py_test_modules TARGET_NAME)
 endfunction()

 # test time-consuming OPs in a separate process to exploit parallelism
+list(REMOVE_ITEM TEST_OPS test_parallel_executor)
 list(REMOVE_ITEM TEST_OPS test_warpctc_op)
 list(REMOVE_ITEM TEST_OPS test_dyn_rnn)
 list(REMOVE_ITEM TEST_OPS test_mul_op)
@@ -64,6 +65,7 @@ else()
 endif(WITH_FAST_BUNDLE_TEST)

 # tests with high overhead
+py_test_modules(test_parallel_executor MODULES test_parallel_executor)
 py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
 py_test_modules(test_train_dyn_rnn MODULES test_dyn_rnn)
 py_test_modules(test_mul_op MODULES test_mul_op)
