
Commit ad73b33

reyoung authored and qingqing01 committed
Eagerly drop local scope in iteration (#9838)
* Eagerly drop local scope in iteration
* Correct create var
* Fix typo
* Debug
1 parent 8d4d6ea commit ad73b33

File tree

7 files changed: +54 −44 lines changed

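Taken together, the changes below move per-iteration scope handling out of ThreadedSSAGraphExecutor and into ParallelExecutor::Run: before each run, every parent scope gets a fresh local execution scope published under details::kLocalExecScopeName, and after the run the computational streams are waited on and the local scope is deleted immediately instead of piling up until a periodic sync. The following is a minimal sketch of that lifecycle, assuming the Scope and DeviceContextPool interfaces that appear in the diff; RunGraph() is a hypothetical stand-in for executing the SSA graph, not an API from this commit.

// Sketch only: condenses the create/run/wait/drop cycle added to
// ParallelExecutor::Run by this commit.
void RunOneIteration(const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places) {
  // Create a local execution scope per device scope and publish a pointer to
  // it under kLocalExecScopeName so op handles can find it.
  for (Scope *scope : local_scopes) {
    Scope &local_scope = scope->NewScope();
    *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
        &local_scope;
  }

  RunGraph();  // hypothetical: op handles resolve and run in the local scopes

  // Wait for all computational streams, then drop the local scopes eagerly.
  for (auto &p : places) {
    platform::DeviceContextPool::Instance().Get(p)->Wait();
  }
  for (Scope *scope : local_scopes) {
    auto &local_scope =
        *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
    scope->DeleteScope(local_scope);
    local_scope = nullptr;
  }
}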

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 3 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 
+#include <string>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -33,7 +35,7 @@ void ComputationOpHandle::RunImpl() {
     }
   }
 
-  op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
+  op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
 }
 
 std::string ComputationOpHandle::Name() const { return op_->Type(); }
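With this change the op handle no longer hard-codes the "@TMP_SCOPE@" variable name; it resolves the per-iteration scope through the shared kLocalExecScopeName constant. A hypothetical helper, just to spell out the indirection the diff inlines into RunImpl() (the parent scope holds a Scope* variable pointing at the current local execution scope):

// Hypothetical helper, not part of the commit; shows the lookup performed by
// ComputationOpHandle::RunImpl() and FetchOpHandle::RunImpl().
Scope *LocalExecScope(const Scope &parent) {
  return parent.FindVar(kLocalExecScopeName)->Get<Scope *>();
}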

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,9 @@
 
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 
+#include <string>
+#include <vector>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -57,7 +60,10 @@ void FetchOpHandle::RunImpl() {
 
   for (size_t i = 0; i < scopes.size(); ++i) {
     auto &scope = scopes[i];
-    auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
+    auto &t = scope->FindVar(kLocalExecScopeName)
+                  ->Get<Scope *>()
+                  ->FindVar(var_name)
+                  ->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(var->place_)) {
 #ifdef PADDLE_WITH_CUDA
       TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
+
 class OpHandleBase {
  private:
   DISABLE_COPY_AND_ASSIGN(OpHandleBase);

paddle/fluid/framework/details/ssa_graph_executor.h

Lines changed: 3 additions & 1 deletion
@@ -15,13 +15,15 @@
 #pragma once
 
 #include <memory>
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
-
 class SSAGraphExecutor {
   DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 0 additions & 30 deletions
@@ -136,12 +136,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     ready_ops.clear();
   };
 
-  // Create local scopes.
-  for (auto &scope : local_scopes_) {
-    auto &local_scope = scope->NewScope();
-    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
-  }
-
   // Step 3. Execution
   while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
     // 1. Run All Ready ops
@@ -189,34 +183,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   PADDLE_ENFORCE(ready_ops.empty());
   PADDLE_ENFORCE(delayed_ops.empty());
   PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
-  ++computation_count_;
-
-  auto sync_computation = [&] {
-    computation_count_ = 0;
-    // Wait All computational streams
-    for (auto p : this->places_) {
-      platform::DeviceContextPool::Instance().Get(p)->Wait();
-    }
-    for (auto &scope : local_scopes_) {
-      scope->DropKids();
-    }
-  };
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
     fetch_ops.clear();
-    sync_computation();
-  }
-
-  if (computation_count_ == max_async_computation) {
-    sync_computation();
-  }
-
-  // NOTE: the temp scope can be dropped lazily if needed.
-  // Drop tmp scopes;
-  for (auto &scope : local_scopes_) {
-    auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
-    kid = nullptr;
   }
 
   return fetch_data;

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 0 additions & 3 deletions
@@ -99,9 +99,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::unique_ptr<platform::EnforceNotMet> exception_;
   std::atomic<int> running_ops_;
   bool allow_op_delay_;
-
-  size_t computation_count_{0};
-  size_t max_async_computation{100};
 };
 
 }  // namespace details

paddle/fluid/framework/parallel_executor.cc

Lines changed: 39 additions & 8 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/parallel_executor.h"
 
 #include <string>
+#include <tuple>
 #include <vector>
 
 #ifdef PADDLE_WITH_CUDA
@@ -41,6 +42,8 @@ class ParallelExecutorPrivate {
 #ifdef PADDLE_WITH_CUDA
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 #endif
+
+  std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
 };
 
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
@@ -97,14 +100,9 @@ ParallelExecutor::ParallelExecutor(
                                               allow_op_delay));
 
   // Step 3. Create vars in each scope;
-  for (auto *scope : member_->local_scopes_) {
-    for (auto *var : main_program.Block(0).AllVars()) {
-      if (scope->FindVar(var->Name()) != nullptr) {
-        continue;
-      }
-
-      InitializeVariable(scope->Var(var->Name()), var->GetType());
-    }
+  for (auto *var : main_program.Block(0).AllVars()) {
+    member_->var_types_.emplace_back(var->Name(), var->GetType(),
+                                     var->Persistable());
   }
 }
 
@@ -163,9 +161,42 @@ void ParallelExecutor::Run(
     const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
   platform::RecordBlock b(0);
   SplitTensorToPlaces(feed_tensors);
+
+  // Create local scopes.
+  for (auto &scope : member_->local_scopes_) {
+    Scope &local_scope = scope->NewScope();
+    *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
+        &local_scope;
+
+    for (auto &name_type_pair : member_->var_types_) {
+      if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
+        continue;
+      }
+
+      if (std::get<2>(name_type_pair)) {  // Persistable
+        InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
+                           std::get<1>(name_type_pair));
+      } else {
+        InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
+                           std::get<1>(name_type_pair));
+      }
+    }
+  }
+
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
       fetch_data;
+
+  // Wait All computational streams
+  for (auto p : member_->places_) {
+    platform::DeviceContextPool::Instance().Get(p)->Wait();
+  }
+  for (auto &scope : member_->local_scopes_) {
+    auto &local_scope =
+        *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
+    scope->DeleteScope(local_scope);
+    local_scope = nullptr;
+  }
 }
 
 void ParallelExecutor::SplitTensorToPlaces(
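member_->var_types_ caches one tuple per variable of block 0 so that Run() can recreate missing variables in each fresh local scope without walking the ProgramDesc again. For readability, the indices used with std::get correspond to the fields below; VarInfo is an illustrative alias, not a name from the diff.

// Illustrative only; mirrors std::tuple<std::string, proto::VarType::Type, bool>.
using VarInfo = std::tuple<std::string,           // std::get<0>: variable name
                           proto::VarType::Type,  // std::get<1>: variable type
                           bool>;                 // std::get<2>: Persistable()

// An entry is produced in the constructor and consumed in Run(), e.g.:
//   member_->var_types_.emplace_back(var->Name(), var->GetType(),
//                                    var->Persistable());
//   InitializeVariable(scope->Var(std::get<0>(info)), std::get<1>(info));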

0 commit comments
