File tree Expand file tree Collapse file tree 4 files changed +21
-0
lines changed Expand file tree Collapse file tree 4 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -49,6 +49,8 @@ struct VarHandleBase {
49
49
50
50
void AddOutput (OpHandleBase* out, ir::Node* node) {
51
51
if (pending_ops_.find (out) == pending_ops_.end ()) {
52
+ PADDLE_ENFORCE (out != nullptr , " The output of %s should not be nullptr" ,
53
+ this ->Node ()->Name ());
52
54
pending_ops_.insert (out);
53
55
node_->outputs .push_back (node);
54
56
}
Original file line number Diff line number Diff line change @@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
299
299
}
300
300
301
301
ParallelExecutor::~ParallelExecutor () {
302
+ const auto dev_ctxs =
303
+ platform::DeviceContextPool::Instance ().GetAllDeviceContexts ();
304
+ for (auto &dev_ctx : dev_ctxs) {
305
+ dev_ctx->Wait ();
306
+ }
307
+
302
308
if (member_->own_local_scope_ ) {
303
309
for (size_t i = 1 ; i < member_->local_scopes_ .size (); ++i) {
304
310
Scope *local_scope = member_->local_scopes_ [i];
Original file line number Diff line number Diff line change @@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
35
35
return it->second .get ();
36
36
}
37
37
38
+ const std::vector<const DeviceContext*>
39
+ DeviceContextPool::GetAllDeviceContexts () const {
40
+ std::vector<const DeviceContext*> all_device_ctx;
41
+ all_device_ctx.reserve (device_contexts_.size ());
42
+ for (auto & dev_ctx : device_contexts_) {
43
+ all_device_ctx.emplace_back (dev_ctx.second .get ());
44
+ }
45
+ return all_device_ctx;
46
+ }
47
+
38
48
DeviceContextPool::DeviceContextPool (
39
49
const std::vector<platform::Place>& places) {
40
50
PADDLE_ENFORCE_GT (places.size (), 0 );
Original file line number Diff line number Diff line change @@ -217,6 +217,9 @@ class DeviceContextPool {
217
217
/* ! \brief Return handle of single device context. */
218
218
platform::DeviceContext* Get (const platform::Place& place);
219
219
220
+ /* ! \brief Return all the device contexts. */
221
+ const std::vector<const DeviceContext*> GetAllDeviceContexts () const ;
222
+
220
223
template <typename Place>
221
224
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace (
222
225
const Place& place) {
You can’t perform that action at this time.
0 commit comments