Skip to content

Commit e533a4b

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap_memcpy_with_dist
2 parents cb38615 + c36dd3b commit e533a4b

21 files changed

+244
-81
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
8787
framework_proto glog lod_rank_table feed_fetch_method)
8888

8989

90-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor)
90+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
9191

9292
cc_library(prune SRCS prune.cc DEPS framework_proto)
9393
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
3636
device_context broadcast_op_handle)
3737
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
3838
device_context gather_op_handle)
39+
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor)
3940
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
4041
# device_context reduce_op_handle )

paddle/fluid/framework/details/execution_strategy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct ExecutionStrategy {
2222
size_t num_threads_{0};
2323
bool use_event_{true};
2424
bool allow_op_delay_{false};
25+
size_t num_iteration_per_drop_scope_{100};
2526
};
2627

2728
} // namespace details
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
16+
#include <string>
17+
#include <vector>
18+
#include "paddle/fluid/framework/executor.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace details {
23+
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
24+
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
25+
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
26+
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
27+
: strategy_(std::move(strategy)),
28+
underlying_executor_(std::move(underlying_executor)),
29+
local_scopes_(std::move(local_scopes)),
30+
var_infos_(std::move(var_infos)),
31+
places_(std::move(places)) {}
32+
33+
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
34+
const std::vector<std::string> &fetch_tensors) {
35+
if (drop_scope_counter_ == 0) {
36+
// Create local scopes.
37+
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
38+
auto &scope = *it;
39+
Scope &local_scope = scope->NewScope();
40+
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
41+
&local_scope;
42+
43+
for (auto &info : var_infos_) {
44+
if (scope->FindVar(info.name_) != nullptr) {
45+
continue;
46+
}
47+
48+
if (info.persistable_) { // Persistable
49+
InitializeVariable(scope->Var(info.name_), info.type_);
50+
} else {
51+
InitializeVariable(local_scope.Var(info.name_), info.type_);
52+
}
53+
}
54+
}
55+
}
56+
57+
auto fetch_data = underlying_executor_->Run(fetch_tensors);
58+
drop_scope_counter_ += 1;
59+
if (!fetch_tensors.empty() ||
60+
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
61+
drop_scope_counter_ = 0;
62+
// Wait All computational streams
63+
for (auto p : places_) {
64+
platform::DeviceContextPool::Instance().Get(p)->Wait();
65+
}
66+
for (auto &scope : local_scopes_) {
67+
auto &local_scope =
68+
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
69+
scope->DeleteScope(local_scope);
70+
}
71+
}
72+
return fetch_data;
73+
}
74+
} // namespace details
75+
} // namespace framework
76+
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include <vector>
20+
#include "paddle/fluid/framework/details/execution_strategy.h"
21+
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
22+
#include "paddle/fluid/framework/scope.h"
23+
#include "paddle/fluid/platform/place.h"
24+
namespace paddle {
25+
namespace framework {
26+
namespace details {
27+
28+
struct VariableInfo {
29+
std::string name_;
30+
proto::VarType::Type type_;
31+
bool persistable_;
32+
};
33+
34+
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
35+
public:
36+
ScopeBufferedSSAGraphExecutor(
37+
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
38+
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
39+
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
40+
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
41+
42+
private:
43+
size_t drop_scope_counter_{0};
44+
45+
ExecutionStrategy strategy_;
46+
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
47+
std::vector<Scope*> local_scopes_;
48+
std::vector<VariableInfo> var_infos_;
49+
std::vector<platform::Place> places_;
50+
};
51+
} // namespace details
52+
} // namespace framework
53+
} // namespace paddle

paddle/fluid/framework/details/ssa_graph_executor.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20-
21-
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
22-
: graph_(std::move(graph)) {}
23-
2420
SSAGraphExecutor::~SSAGraphExecutor() {}
2521

2622
} // namespace details

paddle/fluid/framework/details/ssa_graph_executor.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,11 @@ class SSAGraphExecutor {
2828
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
2929

3030
public:
31-
// Steal graph inside
32-
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
31+
SSAGraphExecutor() {}
3332

3433
virtual ~SSAGraphExecutor();
3534

3635
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
37-
38-
protected:
39-
std::unique_ptr<SSAGraph> graph_;
4036
};
4137
} // namespace details
4238
} // namespace framework

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2121
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
2222
const std::vector<platform::Place> &places,
2323
std::unique_ptr<SSAGraph> &&graph)
24-
: SSAGraphExecutor(std::move(graph)),
24+
: graph_(std::move(graph)),
2525
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
2626
: nullptr),
2727
local_scopes_(local_scopes),
@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
189189
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
190190
auto op_run = [ready_var_q, op, this] {
191191
try {
192-
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
192+
if (VLOG_IS_ON(10)) {
193+
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
194+
}
193195
op->Run(strategy_.use_event_);
194196
VLOG(10) << op << " " << op->Name() << " Done ";
195197
running_ops_--;

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5151
details::OpHandleBase *op);
5252

5353
private:
54+
std::unique_ptr<SSAGraph> graph_;
5455
std::unique_ptr<::ThreadPool> pool_;
5556
std::vector<Scope *> local_scopes_;
5657
std::vector<platform::Place> places_;

paddle/fluid/framework/parallel_executor.cc

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222
#include "paddle/fluid/platform/nccl_helper.h"
2323
#endif
2424

25+
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
2526
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
2627
#include "paddle/fluid/platform/profiler.h"
2728

@@ -41,8 +42,6 @@ class ParallelExecutorPrivate {
4142
#ifdef PADDLE_WITH_CUDA
4243
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
4344
#endif
44-
45-
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
4645
bool own_local_scope;
4746
};
4847

@@ -91,9 +90,18 @@ ParallelExecutor::ParallelExecutor(
9190
local_scopes.empty()) { // Is CUDA
9291
BCastParamsToGPUs(bcast_vars);
9392
}
94-
// Startup Program has been run. All local scopes has correct parameters.
93+
// Startup Program has been run. All local scopes has correct parameters.
94+
95+
// Step 2. Create vars in each scope;
96+
std::vector<details::VariableInfo> var_infos;
97+
for (auto *var : main_program.Block(0).AllVars()) {
98+
var_infos.emplace_back();
99+
var_infos.back().name_ = var->Name();
100+
var_infos.back().type_ = var->GetType();
101+
var_infos.back().persistable_ = var->Persistable();
102+
}
95103

96-
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
104+
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
97105
// ncclOp
98106
#ifdef PADDLE_WITH_CUDA
99107
builder_.reset(new details::MultiDevSSAGraphBuilder(
@@ -106,16 +114,14 @@ ParallelExecutor::ParallelExecutor(
106114
build_strategy));
107115

108116
#endif
109-
auto graph = builder_.get()->Build(main_program);
117+
auto graph = builder_->Build(main_program);
110118

111119
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
112120
exec_strategy, member_->local_scopes_, places, std::move(graph)));
113121

114-
// Step 3. Create vars in each scope;
115-
for (auto *var : main_program.Block(0).AllVars()) {
116-
member_->var_types_.emplace_back(var->Name(), var->GetType(),
117-
var->Persistable());
118-
}
122+
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
123+
exec_strategy, member_->local_scopes_, std::move(var_infos),
124+
member_->places_, std::move(member_->executor_)));
119125
}
120126

121127
void ParallelExecutor::BCastParamsToGPUs(
@@ -178,42 +184,9 @@ void ParallelExecutor::BCastParamsToGPUs(
178184
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
179185
const std::string &fetched_var_name) {
180186
platform::RecordBlock b(0);
181-
// Create local scopes.
182-
for (auto it = member_->local_scopes_.rbegin();
183-
it != member_->local_scopes_.rend(); ++it) {
184-
auto &scope = *it;
185-
Scope &local_scope = scope->NewScope();
186-
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
187-
&local_scope;
188-
189-
for (auto &name_type_pair : member_->var_types_) {
190-
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
191-
continue;
192-
}
193-
194-
if (std::get<2>(name_type_pair)) { // Persistable
195-
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
196-
std::get<1>(name_type_pair));
197-
} else {
198-
InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
199-
std::get<1>(name_type_pair));
200-
}
201-
}
202-
}
203-
204187
auto fetch_data = member_->executor_->Run(fetch_tensors);
205188
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
206189
fetch_data;
207-
208-
// Wait All computational streams
209-
for (auto p : member_->places_) {
210-
platform::DeviceContextPool::Instance().Get(p)->Wait();
211-
}
212-
for (auto &scope : member_->local_scopes_) {
213-
auto &local_scope =
214-
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
215-
scope->DeleteScope(local_scope);
216-
}
217190
}
218191

219192
void ParallelExecutor::FeedTensorsIntoLocalScopes(

0 commit comments

Comments
 (0)