@@ -15,7 +15,6 @@ limitations under the License. */
15
15
#include < algorithm>
16
16
#include < stdexcept>
17
17
#include < string>
18
- #include < vector>
19
18
20
19
#include " paddle/fluid/framework/init.h"
21
20
#include " paddle/fluid/framework/operator.h"
@@ -31,6 +30,7 @@ std::once_flag p2p_init_flag;
31
30
32
31
void InitGflags (std::vector<std::string> argv) {
33
32
std::call_once (gflags_init_flag, [&]() {
33
+ argv.insert (argv.begin (), " dummy" );
34
34
int argc = argv.size ();
35
35
char **arr = new char *[argv.size ()];
36
36
std::string line;
@@ -44,20 +44,23 @@ void InitGflags(std::vector<std::string> argv) {
44
44
});
45
45
}
46
46
47
- void InitP2P (int count ) {
47
+ void InitP2P (std::vector< int > devices ) {
48
48
#ifdef PADDLE_WITH_CUDA
49
49
std::call_once (p2p_init_flag, [&]() {
50
+ int count = devices.size ();
50
51
for (int i = 0 ; i < count; ++i) {
51
52
for (int j = 0 ; j < count; ++j) {
52
- if (i == j ) continue ;
53
+ if (devices[i] == devices[j] ) continue ;
53
54
int can_acess = -1 ;
54
- PADDLE_ENFORCE (cudaDeviceCanAccessPeer (&can_acess, i, j),
55
- " Failed to test P2P access." );
55
+ PADDLE_ENFORCE (
56
+ cudaDeviceCanAccessPeer (&can_acess, devices[i], devices[j]),
57
+ " Failed to test P2P access." );
56
58
if (can_acess != 1 ) {
57
- LOG (WARNING) << " Cannot enable P2P access from " << i << " to " << j;
59
+ LOG (WARNING) << " Cannot enable P2P access from " << devices[i]
60
+ << " to " << devices[j];
58
61
} else {
59
- cudaSetDevice (i );
60
- cudaDeviceEnablePeerAccess (j , 0 );
62
+ cudaSetDevice (devices[i] );
63
+ cudaDeviceEnablePeerAccess (devices[j] , 0 );
61
64
}
62
65
}
63
66
}
@@ -67,11 +70,26 @@ void InitP2P(int count) {
67
70
68
71
void InitDevices (bool init_p2p) {
69
72
/* Init all available devices by default */
73
+ std::vector<int > devices;
74
+ #ifdef PADDLE_WITH_CUDA
75
+ try {
76
+ int count = platform::GetCUDADeviceCount ();
77
+ for (int i = 0 ; i < count; ++i) {
78
+ devices.push_back (i);
79
+ }
80
+ } catch (const std::exception &exp) {
81
+ LOG (WARNING) << " Compiled with WITH_GPU, but no GPU found in runtime." ;
82
+ }
83
+ #else
84
+ LOG (WARNING)
85
+ << " 'CUDA' is not supported, Please re-compile with WITH_GPU option" ;
86
+ #endif
87
+ InitDevices (init_p2p, devices);
88
+ }
70
89
90
+ void InitDevices (bool init_p2p, const std::vector<int > devices) {
71
91
std::vector<platform::Place> places;
72
- places.emplace_back (platform::CPUPlace ());
73
92
int count = 0 ;
74
-
75
93
#ifdef PADDLE_WITH_CUDA
76
94
try {
77
95
count = platform::GetCUDADeviceCount ();
@@ -83,12 +101,17 @@ void InitDevices(bool init_p2p) {
83
101
<< " 'CUDA' is not supported, Please re-compile with WITH_GPU option" ;
84
102
#endif
85
103
86
- for (int i = 0 ; i < count; ++i) {
87
- places.emplace_back (platform::CUDAPlace (i));
104
+ for (size_t i = 0 ; i < devices.size (); ++i) {
105
+ if (devices[i] >= count || devices[i] < 0 ) {
106
+ LOG (WARNING) << " Invalid devices id." ;
107
+ continue ;
108
+ }
109
+ places.emplace_back (platform::CUDAPlace (devices[i]));
88
110
}
89
111
if (init_p2p) {
90
- InitP2P (count );
112
+ InitP2P (devices );
91
113
}
114
+ places.emplace_back (platform::CPUPlace ());
92
115
platform::DeviceContextPool::Init (places);
93
116
}
94
117
0 commit comments