Skip to content

Commit 90d9e5a

Browse files
authored
feat(platform): lazy initialization of devicecontext in pool (#14067)
* feat(platform): lazy initialization of devicecontext in pool Use std::async(std::launch::deferred, []{...}) to lazily initialize DeviceContext in Pool test=develop * Add future includes test=develop
1 parent ab4351f commit 90d9e5a

File tree

3 files changed

+22
-27
lines changed

3 files changed

+22
-27
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
303303
}
304304

305305
ParallelExecutor::~ParallelExecutor() {
306-
const auto dev_ctxs =
307-
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
308-
for (auto &dev_ctx : dev_ctxs) {
309-
dev_ctx->Wait();
306+
for (auto &p : member_->places_) {
307+
platform::DeviceContextPool::Instance().Get(p)->Wait();
310308
}
311309

312310
if (member_->own_local_scope_) {

paddle/fluid/platform/device_context.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
3232
"'Place' is not supported, Please re-compile with WITH_GPU "
3333
"option");
3434
}
35-
return it->second.get();
35+
return it->second.get().get();
3636
}
3737

38-
const std::vector<const DeviceContext*>
39-
DeviceContextPool::GetAllDeviceContexts() const {
40-
std::vector<const DeviceContext*> all_device_ctx;
41-
all_device_ctx.reserve(device_contexts_.size());
42-
for (auto& dev_ctx : device_contexts_) {
43-
all_device_ctx.emplace_back(dev_ctx.second.get());
44-
}
45-
return all_device_ctx;
38+
// Registers a lazily-constructed device context for place `p` in `*map_ptr`.
//
// DevCtx    - concrete context type to construct (e.g. CPUDeviceContext).
// PlaceType - concrete place type extracted from `p` via boost::get.
// map_ptr   - pool map from Place to a shared_future holding the context.
// p         - place to register; captured by copy in the deferred task.
//
// No DevCtx is constructed here: the map stores a future created with
// std::launch::deferred, so the task runs when that future is first
// evaluated (DeviceContextPool::Get calls `.get()` on it).
template <typename DevCtx, typename PlaceType>
39+
inline void EmplaceDeviceContext(
40+
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
41+
map_ptr,
42+
platform::Place p) {
43+
using PtrType = std::unique_ptr<DeviceContext>;
44+
map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
45+
// Lazy evaluation: the device context is only created when the
46+
// future is first evaluated, i.e. at the first `Get` for this place.
47+
return PtrType(new DevCtx(boost::get<PlaceType>(p)));
48+
}));
4649
}
4750

4851
DeviceContextPool::DeviceContextPool(
4952
const std::vector<platform::Place>& places) {
5053
PADDLE_ENFORCE_GT(places.size(), 0);
51-
using PtrType = std::unique_ptr<DeviceContext>;
5254
std::set<Place> set;
5355
for (auto& p : places) {
5456
set.insert(p);
@@ -57,26 +59,22 @@ DeviceContextPool::DeviceContextPool(
5759
for (auto& p : set) {
5860
if (platform::is_cpu_place(p)) {
5961
#ifdef PADDLE_WITH_MKLDNN
60-
device_contexts_.emplace(
61-
p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
62+
EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
6263
#else
63-
device_contexts_.emplace(
64-
p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
64+
EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
6565
#endif
6666
} else if (platform::is_gpu_place(p)) {
6767
#ifdef PADDLE_WITH_CUDA
68-
device_contexts_.emplace(
69-
p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
68+
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
7069
#else
7170
PADDLE_THROW(
7271
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
7372
"option");
7473
#endif
7574
} else if (platform::is_cuda_pinned_place(p)) {
7675
#ifdef PADDLE_WITH_CUDA
77-
device_contexts_.emplace(
78-
p,
79-
PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
76+
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
77+
&device_contexts_, p);
8078
#else
8179
PADDLE_THROW(
8280
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
1010
limitations under the License. */
1111
#pragma once
1212

13+
#include <future> // NOLINT
1314
#include <memory>
1415
#include <mutex> // NOLINT
1516
#include <string>
@@ -223,9 +224,6 @@ class DeviceContextPool {
223224
/*! \brief Return handle of single device context. */
224225
platform::DeviceContext* Get(const platform::Place& place);
225226

226-
/*! \brief Return all the device contexts. */
227-
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
228-
229227
template <typename Place>
230228
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
231229
const Place& place) {
@@ -237,7 +235,8 @@ class DeviceContextPool {
237235

238236
private:
239237
static DeviceContextPool* pool;
240-
std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_;
238+
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
239+
device_contexts_;
241240
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
242241
};
243242

0 commit comments

Comments
 (0)