Skip to content

Commit 82d2903

Browse files
author
chengduozh
committed
Fix fast ParallelExe bug
test=develop
1 parent dcfb687 commit 82d2903

File tree

4 files changed

+21
-0
lines changed

4 files changed

+21
-0
lines changed

paddle/fluid/framework/details/var_handle.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ struct VarHandleBase {
4949

5050
void AddOutput(OpHandleBase* out, ir::Node* node) {
5151
if (pending_ops_.find(out) == pending_ops_.end()) {
52+
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
53+
this->Node()->Name());
5254
pending_ops_.insert(out);
5355
node_->outputs.push_back(node);
5456
}

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
299299
}
300300

301301
ParallelExecutor::~ParallelExecutor() {
302+
const auto dev_ctxs =
303+
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
304+
for (auto &dev_ctx : dev_ctxs) {
305+
dev_ctx->Wait();
306+
}
307+
302308
if (member_->own_local_scope_) {
303309
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
304310
Scope *local_scope = member_->local_scopes_[i];

paddle/fluid/platform/device_context.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
3535
return it->second.get();
3636
}
3737

38+
const std::vector<const DeviceContext*>
39+
DeviceContextPool::GetAllDeviceContexts() const {
40+
std::vector<const DeviceContext*> all_device_ctx;
41+
all_device_ctx.reserve(device_contexts_.size());
42+
for (auto& dev_ctx : device_contexts_) {
43+
all_device_ctx.emplace_back(dev_ctx.second.get());
44+
}
45+
return all_device_ctx;
46+
}
47+
3848
DeviceContextPool::DeviceContextPool(
3949
const std::vector<platform::Place>& places) {
4050
PADDLE_ENFORCE_GT(places.size(), 0);

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ class DeviceContextPool {
217217
/*! \brief Return handle of single device context. */
218218
platform::DeviceContext* Get(const platform::Place& place);
219219

220+
/*! \brief Return all the device contexts. */
221+
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
222+
220223
template <typename Place>
221224
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
222225
const Place& place) {

0 commit comments

Comments
 (0)