Skip to content

Commit b6c8701

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/extract_tensor
2 parents fc9f2d2 + 106ee9d commit b6c8701

24 files changed

+183
-82
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
8787
framework_proto glog lod_rank_table feed_fetch_method)
8888

8989

90-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor)
90+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
9191

9292
cc_library(prune SRCS prune.cc DEPS framework_proto)
9393
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
3636
device_context broadcast_op_handle)
3737
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
3838
device_context gather_op_handle)
39+
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor)
3940
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
4041
# device_context reduce_op_handle )

paddle/fluid/framework/details/execution_strategy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct ExecutionStrategy {
2222
size_t num_threads_{0};
2323
bool use_event_{true};
2424
bool allow_op_delay_{false};
25+
size_t num_iteration_per_drop_scope_{100};
2526
};
2627

2728
} // namespace details
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
16+
#include <string>
17+
#include <vector>
18+
#include "paddle/fluid/framework/executor.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace details {
23+
// Builds an executor that wraps `underlying_executor`.
//
// Every argument is moved into the member of the same name:
//   strategy            - execution options; Run() reads
//                         num_iteration_per_drop_scope_ from it.
//   local_scopes        - scope pointers under which Run() creates and
//                         deletes child scopes. The pointers are stored as
//                         given.
//   var_infos           - variables that Run() creates in those scopes.
//   places              - places whose device contexts Run() waits on before
//                         deleting the child scopes.
//   underlying_executor - executor that Run() delegates to. The unique_ptr is
//                         moved from, so the caller's pointer is left empty.
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
24+
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
25+
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
26+
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
27+
: strategy_(std::move(strategy)),
28+
underlying_executor_(std::move(underlying_executor)),
29+
local_scopes_(std::move(local_scopes)),
30+
var_infos_(std::move(var_infos)),
31+
places_(std::move(places)) {}
32+
33+
// Runs one iteration of the wrapped executor while reusing the local
// execution scopes across iterations.
//
// When no local scopes are alive (drop_scope_counter_ == 0), a child scope is
// created under every entry of local_scopes_, published through the
// kLocalExecScopeName variable, and populated with the variables listed in
// var_infos_ that the parent scope cannot already find. The child scopes are
// deleted again as soon as a fetch was requested, or once
// strategy_.num_iteration_per_drop_scope_ iterations have run.
//
// fetch_tensors: names forwarded unchanged to the underlying executor.
// Returns: the FeedFetchList produced by the underlying executor.
// Throws: whatever the underlying executor throws. The iteration bookkeeping
// runs before the exception is rethrown, so a failed iteration is counted
// like a successful one and the next call does not create a second set of
// child scopes on top of the existing ones.
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  if (drop_scope_counter_ == 0) {
    // Create local scopes.
    for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
      auto &scope = *it;
      Scope &local_scope = scope->NewScope();
      *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
          &local_scope;

      for (auto &info : var_infos_) {
        if (scope->FindVar(info.name_) != nullptr) {
          continue;
        }

        if (info.persistable_) {  // Persistable
          InitializeVariable(scope->Var(info.name_), info.type_);
        } else {
          InitializeVariable(local_scope.Var(info.name_), info.type_);
        }
      }
    }
  }

  // Counts the finished iteration and, when due, waits for the devices and
  // deletes the child scopes. Shared by the normal and the exception path.
  auto finish_iteration = [this, &fetch_tensors]() {
    drop_scope_counter_ += 1;
    if (fetch_tensors.empty() &&
        drop_scope_counter_ != strategy_.num_iteration_per_drop_scope_) {
      return;
    }
    drop_scope_counter_ = 0;
    // Wait for all computational streams before the scopes are deleted.
    for (const auto &p : places_) {
      platform::DeviceContextPool::Instance().Get(p)->Wait();
    }
    for (auto &scope : local_scopes_) {
      auto &local_scope =
          *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
      scope->DeleteScope(local_scope);
    }
  };

  FeedFetchList fetch_data;
  try {
    fetch_data = underlying_executor_->Run(fetch_tensors);
  } catch (...) {
    // Keep the counter and the child scopes consistent even though the
    // iteration failed. If the bookkeeping itself throws, that exception
    // replaces the original one.
    finish_iteration();
    throw;
  }
  finish_iteration();
  return fetch_data;
}
74+
} // namespace details
75+
} // namespace framework
76+
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include <vector>
20+
#include "paddle/fluid/framework/details/execution_strategy.h"
21+
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
22+
#include "paddle/fluid/framework/scope.h"
23+
#include "paddle/fluid/platform/place.h"
24+
namespace paddle {
25+
namespace framework {
26+
namespace details {
27+
28+
// Describes one variable that ScopeBufferedSSAGraphExecutor::Run creates in
// the execution scopes when it is not already present.
struct VariableInfo {
29+
// Name used to look the variable up and to create it.
std::string name_;
30+
// Type handed to InitializeVariable when the variable is created.
proto::VarType::Type type_;
31+
// True: the variable is created in the parent scope. False: it is created
// in the child scope that Run() deletes again later.
bool persistable_;
32+
};
33+
34+
// SSAGraphExecutor that delegates execution to another SSAGraphExecutor and
// manages the child ("local execution") scopes around it: Run() creates them
// when none are alive and deletes them only after a fetch or after
// ExecutionStrategy::num_iteration_per_drop_scope_ iterations, instead of on
// every iteration.
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
35+
public:
36+
// Moves every argument into the member of the same name; the unique_ptr
// passed as underlying_executor is left empty.
ScopeBufferedSSAGraphExecutor(
37+
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
38+
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
39+
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
40+
// Runs the underlying executor once with fetch_tensors and returns its
// result, creating and deleting the child scopes as described above.
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
41+
42+
private:
43+
// Iterations run since the child scopes were last deleted; 0 means Run()
// has to create them.
size_t drop_scope_counter_{0};
44+
45+
// Supplies num_iteration_per_drop_scope_.
ExecutionStrategy strategy_;
46+
// Executor that Run() delegates to.
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
47+
// Parent scopes; Run() creates one child scope under each.
std::vector<Scope*> local_scopes_;
48+
// Variables to create in the parent or child scopes.
std::vector<VariableInfo> var_infos_;
49+
// Places whose device contexts are waited on before scopes are deleted.
std::vector<platform::Place> places_;
50+
};
51+
} // namespace details
52+
} // namespace framework
53+
} // namespace paddle

paddle/fluid/framework/details/ssa_graph_executor.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20-
21-
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
22-
: graph_(std::move(graph)) {}
23-
2420
SSAGraphExecutor::~SSAGraphExecutor() {}
2521

2622
} // namespace details

paddle/fluid/framework/details/ssa_graph_executor.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,11 @@ class SSAGraphExecutor {
2828
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
2929

3030
public:
31-
// Steal graph inside
32-
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
31+
SSAGraphExecutor() {}
3332

3433
virtual ~SSAGraphExecutor();
3534

3635
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
37-
38-
protected:
39-
std::unique_ptr<SSAGraph> graph_;
4036
};
4137
} // namespace details
4238
} // namespace framework

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2121
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
2222
const std::vector<platform::Place> &places,
2323
std::unique_ptr<SSAGraph> &&graph)
24-
: SSAGraphExecutor(std::move(graph)),
24+
: graph_(std::move(graph)),
2525
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
2626
: nullptr),
2727
local_scopes_(local_scopes),
@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
189189
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
190190
auto op_run = [ready_var_q, op, this] {
191191
try {
192-
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
192+
if (VLOG_IS_ON(10)) {
193+
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
194+
}
193195
op->Run(strategy_.use_event_);
194196
VLOG(10) << op << " " << op->Name() << " Done ";
195197
running_ops_--;

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5151
details::OpHandleBase *op);
5252

5353
private:
54+
std::unique_ptr<SSAGraph> graph_;
5455
std::unique_ptr<::ThreadPool> pool_;
5556
std::vector<Scope *> local_scopes_;
5657
std::vector<platform::Place> places_;

paddle/fluid/framework/parallel_executor.cc

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
#endif
2424

2525
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
26+
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
2627
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
2728
#include "paddle/fluid/platform/profiler.h"
2829

@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
4243
#ifdef PADDLE_WITH_CUDA
4344
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
4445
#endif
45-
46-
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
4746
bool own_local_scope;
4847
};
4948

@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
9291
local_scopes.empty()) { // Is CUDA
9392
BCastParamsToGPUs(bcast_vars);
9493
}
95-
// Startup Program has been run. All local scopes has correct parameters.
94+
// Startup Program has been run. All local scopes have correct parameters.
95+
96+
// Step 2. Create vars in each scope;
97+
std::vector<details::VariableInfo> var_infos;
98+
for (auto *var : main_program.Block(0).AllVars()) {
99+
var_infos.emplace_back();
100+
var_infos.back().name_ = var->Name();
101+
var_infos.back().type_ = var->GetType();
102+
var_infos.back().persistable_ = var->Persistable();
103+
}
96104

97-
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
105+
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
98106
// ncclOp
99107
#ifdef PADDLE_WITH_CUDA
100108
details::MultiDevSSAGraphBuilder builder(
@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
105113
params, member_->local_scopes_,
106114
build_strategy);
107115
#endif
116+
108117
auto graph = builder.Build(main_program);
109118

110119
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
111120
exec_strategy, member_->local_scopes_, places, std::move(graph)));
112121

113-
// Step 3. Create vars in each scope;
114-
for (auto *var : main_program.Block(0).AllVars()) {
115-
member_->var_types_.emplace_back(var->Name(), var->GetType(),
116-
var->Persistable());
117-
}
122+
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
123+
exec_strategy, member_->local_scopes_, std::move(var_infos),
124+
member_->places_, std::move(member_->executor_)));
118125
}
119126

120127
void ParallelExecutor::BCastParamsToGPUs(
@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
169176
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
170177
const std::string &fetched_var_name) {
171178
platform::RecordBlock b(0);
172-
// Create local scopes.
173-
for (auto it = member_->local_scopes_.rbegin();
174-
it != member_->local_scopes_.rend(); ++it) {
175-
auto &scope = *it;
176-
Scope &local_scope = scope->NewScope();
177-
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
178-
&local_scope;
179-
180-
for (auto &name_type_pair : member_->var_types_) {
181-
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
182-
continue;
183-
}
184-
185-
if (std::get<2>(name_type_pair)) { // Persistable
186-
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
187-
std::get<1>(name_type_pair));
188-
} else {
189-
InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
190-
std::get<1>(name_type_pair));
191-
}
192-
}
193-
}
194-
195179
auto fetch_data = member_->executor_->Run(fetch_tensors);
196180
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
197181
fetch_data;
198-
199-
// Wait All computational streams
200-
for (auto p : member_->places_) {
201-
platform::DeviceContextPool::Instance().Get(p)->Wait();
202-
}
203-
for (auto &scope : member_->local_scopes_) {
204-
auto &local_scope =
205-
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
206-
scope->DeleteScope(local_scope);
207-
}
208182
}
209183

210184
void ParallelExecutor::FeedTensorsIntoLocalScopes(

0 commit comments

Comments
 (0)