Skip to content

Commit dc863aa

Browse files
committed
Add kids exists detection in Scope
1 parent 681514e commit dc863aa

16 files changed

+60
-79
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ namespace framework {
2222
namespace details {
2323

2424
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
25-
const ExecutionStrategy &strategy,
26-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
25+
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
2726
const std::vector<platform::Place> &places,
2827
std::unique_ptr<ir::Graph> &&graph)
2928
: strategy_(strategy),

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,16 @@ namespace details {
2929
class OpHandleBase;
3030
class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
3131
public:
32-
FastThreadedSSAGraphExecutor(
33-
const ExecutionStrategy &strategy,
34-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
35-
const std::vector<platform::Place> &places,
36-
std::unique_ptr<ir::Graph> &&graph);
32+
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
33+
const std::vector<Scope *> &local_scopes,
34+
const std::vector<platform::Place> &places,
35+
std::unique_ptr<ir::Graph> &&graph);
3736
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
3837
const ir::Graph &Graph() const override;
3938

4039
private:
4140
ExecutionStrategy strategy_;
42-
std::vector<std::shared_ptr<Scope>> local_scopes_;
41+
std::vector<Scope *> local_scopes_;
4342
std::vector<platform::Place> places_;
4443
std::unique_ptr<ir::Graph> graph_;
4544

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace framework {
2222
namespace details {
2323

2424
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
25-
std::vector<std::shared_ptr<Scope>> *local_scopes)
25+
std::vector<Scope *> *local_scopes)
2626
: OpHandleBase(node),
2727
data_(data),
2828
offset_(offset),

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace details {
2929
struct FetchOpHandle : public OpHandleBase {
3030
public:
3131
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
32-
std::vector<std::shared_ptr<Scope>> *local_scopes);
32+
std::vector<Scope *> *local_scopes);
3333

3434
~FetchOpHandle();
3535

@@ -47,7 +47,7 @@ struct FetchOpHandle : public OpHandleBase {
4747
private:
4848
FeedFetchList *data_;
4949
size_t offset_;
50-
std::vector<std::shared_ptr<Scope>> *local_scopes_;
50+
std::vector<Scope *> *local_scopes_;
5151
std::vector<LoDTensor> tensors_;
5252
};
5353

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ namespace paddle {
2323
namespace framework {
2424
namespace details {
2525
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
26-
ExecutionStrategy strategy,
27-
std::vector<std::shared_ptr<Scope>> local_scopes,
26+
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
2827
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
2928
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
3029
: strategy_(std::move(strategy)),

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ struct VariableInfo {
3737
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
3838
public:
3939
ScopeBufferedSSAGraphExecutor(
40-
ExecutionStrategy strategy,
41-
std::vector<std::shared_ptr<Scope>> local_scopes,
40+
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
4241
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
4342
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
4443

@@ -53,7 +52,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
5352

5453
ExecutionStrategy strategy_;
5554
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
56-
std::vector<std::shared_ptr<Scope>> local_scopes_;
55+
std::vector<Scope*> local_scopes_;
5756
std::vector<VariableInfo> var_infos_;
5857
std::vector<platform::Place> places_;
5958
};

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ namespace paddle {
2121
namespace framework {
2222
namespace details {
2323
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
24-
const ExecutionStrategy &strategy,
25-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
24+
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
2625
const std::vector<platform::Place> &places,
2726
std::unique_ptr<ir::Graph> &&graph)
2827
: graph_(std::move(graph)),

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,10 @@ namespace details {
3838

3939
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
4040
public:
41-
ThreadedSSAGraphExecutor(
42-
const ExecutionStrategy &strategy,
43-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
44-
const std::vector<platform::Place> &places,
45-
std::unique_ptr<ir::Graph> &&graph);
41+
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
42+
const std::vector<Scope *> &local_scopes,
43+
const std::vector<platform::Place> &places,
44+
std::unique_ptr<ir::Graph> &&graph);
4645

4746
const ir::Graph &Graph() const override { return *graph_; }
4847
// Run a SSAGraph by a thread pool
@@ -58,7 +57,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5857
private:
5958
std::unique_ptr<ir::Graph> graph_;
6059
std::unique_ptr<::ThreadPool> pool_;
61-
std::vector<std::shared_ptr<Scope>> local_scopes_;
60+
std::vector<Scope *> local_scopes_;
6261
std::vector<platform::Place> places_;
6362
platform::DeviceContextPool fetch_ctxs_;
6463
ExceptionHolder exception_holder_;

paddle/fluid/framework/parallel_executor.cc

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
3939
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
4040
const std::string &loss_var_name,
4141
const std::unordered_set<std::string> &param_names,
42-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
43-
const bool use_cuda,
42+
const std::vector<Scope *> &local_scopes, const bool use_cuda,
4443
#ifdef PADDLE_WITH_CUDA
4544
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
4645
#else
@@ -67,8 +66,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
6766
&loss_var_name);
6867
multi_devices_pass->SetNotOwned<const std::unordered_set<std::string>>(
6968
"params", &param_names);
70-
multi_devices_pass->SetNotOwned<const std::vector<std::shared_ptr<Scope>>>(
71-
"local_scopes", &local_scopes);
69+
multi_devices_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
70+
&local_scopes);
7271
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
7372

7473
#ifdef PADDLE_WITH_CUDA
@@ -101,8 +100,8 @@ class ParallelExecutorPrivate {
101100
: places_(places) {}
102101

103102
std::vector<platform::Place> places_;
104-
std::vector<std::shared_ptr<Scope>> local_scopes_;
105-
std::shared_ptr<Scope> global_scope_;
103+
std::vector<Scope *> local_scopes_;
104+
Scope *global_scope_;
106105
std::unique_ptr<details::SSAGraphExecutor> executor_;
107106

108107
#ifdef PADDLE_WITH_CUDA
@@ -113,7 +112,7 @@ class ParallelExecutorPrivate {
113112
bool use_all_reduce_;
114113
};
115114

116-
std::vector<std::shared_ptr<Scope>> &ParallelExecutor::GetLocalScopes() {
115+
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
117116
return member_->local_scopes_;
118117
}
119118

@@ -122,8 +121,7 @@ ParallelExecutor::ParallelExecutor(
122121
const std::unordered_set<std::string> &params,
123122
const std::unordered_set<std::string> &bcast_vars,
124123
const ProgramDesc &main_program, const std::string &loss_var_name,
125-
const std::shared_ptr<Scope> &scope,
126-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
124+
Scope *scope, const std::vector<Scope *> &local_scopes,
127125
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
128126
size_t num_trainers, size_t trainer_id)
129127
: member_(new ParallelExecutorPrivate(places)) {
@@ -144,13 +142,13 @@ ParallelExecutor::ParallelExecutor(
144142
member_->own_local_scope_ = true;
145143
member_->local_scopes_.emplace_back(member_->global_scope_);
146144
for (size_t i = 1; i < member_->places_.size(); ++i) {
147-
member_->local_scopes_.emplace_back(scope->NewSharedScope());
145+
member_->local_scopes_.emplace_back(&scope->NewScope());
148146
}
149147
} else {
150148
member_->own_local_scope_ = false;
151149
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
152150
for (size_t i = 0; i < member_->places_.size(); ++i) {
153-
member_->local_scopes_.emplace_back(local_scopes[i]->NewSharedScope());
151+
member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
154152
}
155153
}
156154

@@ -323,7 +321,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
323321

324322
for (size_t i = 0; i < tensors.size(); ++i) {
325323
auto &map = tensors[i];
326-
auto &scope = member_->local_scopes_[i];
324+
auto *scope = member_->local_scopes_[i];
327325
for (auto &pair : map) {
328326
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
329327
trg->ShareDataWith(pair.second);
@@ -353,15 +351,11 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
353351

354352
ParallelExecutor::~ParallelExecutor() {
355353
if (member_->own_local_scope_) {
356-
std::vector<Scope *> local_scopes_ptrs;
357-
local_scopes_ptrs.reserve(member_->local_scopes_.size());
358354
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
359-
local_scopes_ptrs.emplace_back(member_->local_scopes_[i].get());
360-
member_->local_scopes_[i].reset();
361-
}
362-
363-
for (size_t i = 0; i != local_scopes_ptrs.size(); ++i) {
364-
member_->global_scope_->DeleteScope(local_scopes_ptrs[i]);
355+
Scope *local_scope = member_->local_scopes_[i];
356+
if (member_->global_scope_->HasKid(local_scope)) {
357+
member_->global_scope_->DeleteScope(local_scope);
358+
}
365359
}
366360
}
367361
}

paddle/fluid/framework/parallel_executor.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,19 @@ class ParallelExecutor {
3939
DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
4040

4141
public:
42-
explicit ParallelExecutor(
43-
const std::vector<platform::Place> &places,
44-
const std::unordered_set<std::string> &params,
45-
const std::unordered_set<std::string> &bcast_vars,
46-
const ProgramDesc &main_program, const std::string &loss_var_name,
47-
const std::shared_ptr<Scope> &scope,
48-
const std::vector<std::shared_ptr<Scope>> &local_scopes,
49-
const ExecutionStrategy &exec_strategy,
50-
const BuildStrategy &build_strategy, size_t num_trainers = 1,
51-
size_t trainer_id = 0);
42+
explicit ParallelExecutor(const std::vector<platform::Place> &places,
43+
const std::unordered_set<std::string> &params,
44+
const std::unordered_set<std::string> &bcast_vars,
45+
const ProgramDesc &main_program,
46+
const std::string &loss_var_name, Scope *scope,
47+
const std::vector<Scope *> &local_scopes,
48+
const ExecutionStrategy &exec_strategy,
49+
const BuildStrategy &build_strategy,
50+
size_t num_trainers = 1, size_t trainer_id = 0);
5251

5352
~ParallelExecutor();
5453

55-
std::vector<std::shared_ptr<Scope>> &GetLocalScopes();
54+
std::vector<Scope *> &GetLocalScopes();
5655

5756
/**
5857
* Feed tensors to local scopes. The size of tensors should be equal to the

0 commit comments

Comments
 (0)