@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
32
32
" 'Place' is not supported, Please re-compile with WITH_GPU "
33
33
" option" );
34
34
}
35
- return it->second .get ();
35
+ return it->second .get (). get () ;
36
36
}
37
37
38
- const std::vector<const DeviceContext*>
39
- DeviceContextPool::GetAllDeviceContexts () const {
40
- std::vector<const DeviceContext*> all_device_ctx;
41
- all_device_ctx.reserve (device_contexts_.size ());
42
- for (auto & dev_ctx : device_contexts_) {
43
- all_device_ctx.emplace_back (dev_ctx.second .get ());
44
- }
45
- return all_device_ctx;
38
// Registers a lazily-constructed device context for place `p` in `*map_ptr`.
//
// Template parameters:
//   DevCtx    - concrete context type that the deferred task instantiates
//               with `new DevCtx(...)`.
//   PlaceType - the alternative extracted from `p` via boost::get<PlaceType>
//               and handed to the DevCtx constructor.
// Parameters:
//   map_ptr - destination map from Place to a shared_future that yields the
//             owning std::unique_ptr<DeviceContext>.
//   p       - the place used as the map key; the lambda captures it by copy
//             ([=]), so the task does not refer back to this parameter.
//
// Because the task is launched with std::launch::deferred, no DevCtx is
// constructed by this call: the lambda runs on the first get()/wait() issued
// on the stored future. boost::get<PlaceType>(p) is also evaluated inside the
// task, so a PlaceType that does not match `p` would only surface at that
// point, not at registration time.
// std::map::emplace does not replace an existing entry, so a key that is
// already present keeps its original future.
+ template <typename DevCtx, typename PlaceType>
39
+ inline void EmplaceDeviceContext (
40
+ std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
41
+ map_ptr,
42
+ platform::Place p) {
43
+ using PtrType = std::unique_ptr<DeviceContext>;
44
+ map_ptr->emplace (p, std::async (std::launch::deferred, [=] {
45
+ // lazy evaluation. i.e., only create device context at
46
+ // first `Get`
47
+ return PtrType (new DevCtx (boost::get<PlaceType>(p)));
48
+ }));
46
49
}
47
50
48
51
DeviceContextPool::DeviceContextPool (
49
52
const std::vector<platform::Place>& places) {
50
53
PADDLE_ENFORCE_GT (places.size (), 0 );
51
- using PtrType = std::unique_ptr<DeviceContext>;
52
54
std::set<Place> set;
53
55
for (auto & p : places) {
54
56
set.insert (p);
@@ -57,26 +59,22 @@ DeviceContextPool::DeviceContextPool(
57
59
for (auto & p : set) {
58
60
if (platform::is_cpu_place (p)) {
59
61
#ifdef PADDLE_WITH_MKLDNN
60
- device_contexts_.emplace (
61
- p, PtrType (new MKLDNNDeviceContext (boost::get<CPUPlace>(p))));
62
+ EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
62
63
#else
63
- device_contexts_.emplace (
64
- p, PtrType (new CPUDeviceContext (boost::get<CPUPlace>(p))));
64
+ EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
65
65
#endif
66
66
} else if (platform::is_gpu_place (p)) {
67
67
#ifdef PADDLE_WITH_CUDA
68
- device_contexts_.emplace (
69
- p, PtrType (new CUDADeviceContext (boost::get<CUDAPlace>(p))));
68
+ EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
70
69
#else
71
70
PADDLE_THROW (
72
71
" 'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
73
72
" option" );
74
73
#endif
75
74
} else if (platform::is_cuda_pinned_place (p)) {
76
75
#ifdef PADDLE_WITH_CUDA
77
- device_contexts_.emplace (
78
- p,
79
- PtrType (new CUDAPinnedDeviceContext (boost::get<CUDAPinnedPlace>(p))));
76
+ EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
77
+ &device_contexts_, p);
80
78
#else
81
79
PADDLE_THROW (
82
80
" 'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
0 commit comments