Skip to content

Commit 681514e

Browse files
committed
Make all scope pointer to shared
1 parent ce24a92 commit 681514e

12 files changed

+63
-44
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace framework {
2222
namespace details {
2323

2424
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
25-
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
25+
const ExecutionStrategy &strategy,
26+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
2627
const std::vector<platform::Place> &places,
2728
std::unique_ptr<ir::Graph> &&graph)
2829
: strategy_(strategy),

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,17 @@ namespace details {
2929
class OpHandleBase;
3030
class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
3131
public:
32-
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
33-
const std::vector<Scope *> &local_scopes,
34-
const std::vector<platform::Place> &places,
35-
std::unique_ptr<ir::Graph> &&graph);
32+
FastThreadedSSAGraphExecutor(
33+
const ExecutionStrategy &strategy,
34+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
35+
const std::vector<platform::Place> &places,
36+
std::unique_ptr<ir::Graph> &&graph);
3637
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
3738
const ir::Graph &Graph() const override;
3839

3940
private:
4041
ExecutionStrategy strategy_;
41-
std::vector<Scope *> local_scopes_;
42+
std::vector<std::shared_ptr<Scope>> local_scopes_;
4243
std::vector<platform::Place> places_;
4344
std::unique_ptr<ir::Graph> graph_;
4445

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace framework {
2222
namespace details {
2323

2424
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
25-
std::vector<Scope *> *local_scopes)
25+
std::vector<std::shared_ptr<Scope>> *local_scopes)
2626
: OpHandleBase(node),
2727
data_(data),
2828
offset_(offset),

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace details {
2929
struct FetchOpHandle : public OpHandleBase {
3030
public:
3131
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
32-
std::vector<Scope *> *local_scopes);
32+
std::vector<std::shared_ptr<Scope>> *local_scopes);
3333

3434
~FetchOpHandle();
3535

@@ -47,7 +47,7 @@ struct FetchOpHandle : public OpHandleBase {
4747
private:
4848
FeedFetchList *data_;
4949
size_t offset_;
50-
std::vector<Scope *> *local_scopes_;
50+
std::vector<std::shared_ptr<Scope>> *local_scopes_;
5151
std::vector<LoDTensor> tensors_;
5252
};
5353

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace paddle {
2323
namespace framework {
2424
namespace details {
2525
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
26-
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
26+
ExecutionStrategy strategy,
27+
std::vector<std::shared_ptr<Scope>> local_scopes,
2728
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
2829
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
2930
: strategy_(std::move(strategy)),

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ struct VariableInfo {
3737
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
3838
public:
3939
ScopeBufferedSSAGraphExecutor(
40-
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
40+
ExecutionStrategy strategy,
41+
std::vector<std::shared_ptr<Scope>> local_scopes,
4142
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
4243
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
4344

@@ -52,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
5253

5354
ExecutionStrategy strategy_;
5455
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
55-
std::vector<Scope*> local_scopes_;
56+
std::vector<std::shared_ptr<Scope>> local_scopes_;
5657
std::vector<VariableInfo> var_infos_;
5758
std::vector<platform::Place> places_;
5859
};

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace paddle {
2121
namespace framework {
2222
namespace details {
2323
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
24-
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
24+
const ExecutionStrategy &strategy,
25+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
2526
const std::vector<platform::Place> &places,
2627
std::unique_ptr<ir::Graph> &&graph)
2728
: graph_(std::move(graph)),

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ namespace details {
3838

3939
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
4040
public:
41-
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
42-
const std::vector<Scope *> &local_scopes,
43-
const std::vector<platform::Place> &places,
44-
std::unique_ptr<ir::Graph> &&graph);
41+
ThreadedSSAGraphExecutor(
42+
const ExecutionStrategy &strategy,
43+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
44+
const std::vector<platform::Place> &places,
45+
std::unique_ptr<ir::Graph> &&graph);
4546

4647
const ir::Graph &Graph() const override { return *graph_; }
4748
// Run a SSAGraph by a thread pool
@@ -57,7 +58,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5758
private:
5859
std::unique_ptr<ir::Graph> graph_;
5960
std::unique_ptr<::ThreadPool> pool_;
60-
std::vector<Scope *> local_scopes_;
61+
std::vector<std::shared_ptr<Scope>> local_scopes_;
6162
std::vector<platform::Place> places_;
6263
platform::DeviceContextPool fetch_ctxs_;
6364
ExceptionHolder exception_holder_;

paddle/fluid/framework/parallel_executor.cc

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
3939
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
4040
const std::string &loss_var_name,
4141
const std::unordered_set<std::string> &param_names,
42-
const std::vector<Scope *> &local_scopes, const bool use_cuda,
42+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
43+
const bool use_cuda,
4344
#ifdef PADDLE_WITH_CUDA
4445
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
4546
#else
@@ -66,8 +67,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
6667
&loss_var_name);
6768
multi_devices_pass->SetNotOwned<const std::unordered_set<std::string>>(
6869
"params", &param_names);
69-
multi_devices_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
70-
&local_scopes);
70+
multi_devices_pass->SetNotOwned<const std::vector<std::shared_ptr<Scope>>>(
71+
"local_scopes", &local_scopes);
7172
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
7273

7374
#ifdef PADDLE_WITH_CUDA
@@ -100,8 +101,8 @@ class ParallelExecutorPrivate {
100101
: places_(places) {}
101102

102103
std::vector<platform::Place> places_;
103-
std::vector<Scope *> local_scopes_;
104-
Scope *global_scope_;
104+
std::vector<std::shared_ptr<Scope>> local_scopes_;
105+
std::shared_ptr<Scope> global_scope_;
105106
std::unique_ptr<details::SSAGraphExecutor> executor_;
106107

107108
#ifdef PADDLE_WITH_CUDA
@@ -112,7 +113,7 @@ class ParallelExecutorPrivate {
112113
bool use_all_reduce_;
113114
};
114115

115-
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
116+
std::vector<std::shared_ptr<Scope>> &ParallelExecutor::GetLocalScopes() {
116117
return member_->local_scopes_;
117118
}
118119

@@ -121,7 +122,8 @@ ParallelExecutor::ParallelExecutor(
121122
const std::unordered_set<std::string> &params,
122123
const std::unordered_set<std::string> &bcast_vars,
123124
const ProgramDesc &main_program, const std::string &loss_var_name,
124-
Scope *scope, const std::vector<Scope *> &local_scopes,
125+
const std::shared_ptr<Scope> &scope,
126+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
125127
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
126128
size_t num_trainers, size_t trainer_id)
127129
: member_(new ParallelExecutorPrivate(places)) {
@@ -142,13 +144,13 @@ ParallelExecutor::ParallelExecutor(
142144
member_->own_local_scope_ = true;
143145
member_->local_scopes_.emplace_back(member_->global_scope_);
144146
for (size_t i = 1; i < member_->places_.size(); ++i) {
145-
member_->local_scopes_.emplace_back(&scope->NewScope());
147+
member_->local_scopes_.emplace_back(scope->NewSharedScope());
146148
}
147149
} else {
148150
member_->own_local_scope_ = false;
149151
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
150152
for (size_t i = 0; i < member_->places_.size(); ++i) {
151-
member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
153+
member_->local_scopes_.emplace_back(local_scopes[i]->NewSharedScope());
152154
}
153155
}
154156

@@ -321,7 +323,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
321323

322324
for (size_t i = 0; i < tensors.size(); ++i) {
323325
auto &map = tensors[i];
324-
auto *scope = member_->local_scopes_[i];
326+
auto &scope = member_->local_scopes_[i];
325327
for (auto &pair : map) {
326328
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
327329
trg->ShareDataWith(pair.second);
@@ -351,8 +353,15 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
351353

352354
ParallelExecutor::~ParallelExecutor() {
353355
if (member_->own_local_scope_) {
356+
std::vector<Scope *> local_scopes_ptrs;
357+
local_scopes_ptrs.reserve(member_->local_scopes_.size());
354358
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
355-
member_->global_scope_->DeleteScope(member_->local_scopes_[i]);
359+
local_scopes_ptrs.emplace_back(member_->local_scopes_[i].get());
360+
member_->local_scopes_[i].reset();
361+
}
362+
363+
for (size_t i = 0; i != local_scopes_ptrs.size(); ++i) {
364+
member_->global_scope_->DeleteScope(local_scopes_ptrs[i]);
356365
}
357366
}
358367
}

paddle/fluid/framework/parallel_executor.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,20 @@ class ParallelExecutor {
3939
DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
4040

4141
public:
42-
explicit ParallelExecutor(const std::vector<platform::Place> &places,
43-
const std::unordered_set<std::string> &params,
44-
const std::unordered_set<std::string> &bcast_vars,
45-
const ProgramDesc &main_program,
46-
const std::string &loss_var_name, Scope *scope,
47-
const std::vector<Scope *> &local_scopes,
48-
const ExecutionStrategy &exec_strategy,
49-
const BuildStrategy &build_strategy,
50-
size_t num_trainers = 1, size_t trainer_id = 0);
42+
explicit ParallelExecutor(
43+
const std::vector<platform::Place> &places,
44+
const std::unordered_set<std::string> &params,
45+
const std::unordered_set<std::string> &bcast_vars,
46+
const ProgramDesc &main_program, const std::string &loss_var_name,
47+
const std::shared_ptr<Scope> &scope,
48+
const std::vector<std::shared_ptr<Scope>> &local_scopes,
49+
const ExecutionStrategy &exec_strategy,
50+
const BuildStrategy &build_strategy, size_t num_trainers = 1,
51+
size_t trainer_id = 0);
5152

5253
~ParallelExecutor();
5354

54-
std::vector<Scope *> &GetLocalScopes();
55+
std::vector<std::shared_ptr<Scope>> &GetLocalScopes();
5556

5657
/**
5758
* Feed tensors to local scopes. The size of tensors should be equal to the

0 commit comments

Comments
 (0)