Commit 2f9fa9b

Merge pull request #10167 from wanghaoshuang/fluid_init
Add init interface for customize devices.
2 parents 4a5bfa8 + 848fb00 · commit 2f9fa9b

File tree: 4 files changed, +54 −17 lines
  paddle/fluid/framework/init.cc
  paddle/fluid/framework/init.h
  paddle/fluid/inference/io.cc
  paddle/fluid/inference/io.h

paddle/fluid/framework/init.cc

Lines changed: 36 additions & 13 deletions
@@ -15,7 +15,6 @@ limitations under the License. */
 #include <algorithm>
 #include <stdexcept>
 #include <string>
-#include <vector>
 
 #include "paddle/fluid/framework/init.h"
 #include "paddle/fluid/framework/operator.h"
@@ -31,6 +30,7 @@ std::once_flag p2p_init_flag;
 
 void InitGflags(std::vector<std::string> argv) {
   std::call_once(gflags_init_flag, [&]() {
+    argv.insert(argv.begin(), "dummy");
     int argc = argv.size();
     char **arr = new char *[argv.size()];
     std::string line;
@@ -44,20 +44,23 @@ void InitGflags(std::vector<std::string> argv) {
   });
 }
 
-void InitP2P(int count) {
+void InitP2P(std::vector<int> devices) {
 #ifdef PADDLE_WITH_CUDA
   std::call_once(p2p_init_flag, [&]() {
+    int count = devices.size();
     for (int i = 0; i < count; ++i) {
       for (int j = 0; j < count; ++j) {
-        if (i == j) continue;
+        if (devices[i] == devices[j]) continue;
         int can_acess = -1;
-        PADDLE_ENFORCE(cudaDeviceCanAccessPeer(&can_acess, i, j),
-                       "Failed to test P2P access.");
+        PADDLE_ENFORCE(
+            cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
+            "Failed to test P2P access.");
         if (can_acess != 1) {
-          LOG(WARNING) << "Cannot enable P2P access from " << i << " to " << j;
+          LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
+                       << " to " << devices[j];
        } else {
-          cudaSetDevice(i);
-          cudaDeviceEnablePeerAccess(j, 0);
+          cudaSetDevice(devices[i]);
+          cudaDeviceEnablePeerAccess(devices[j], 0);
         }
       }
     }
@@ -67,11 +70,26 @@ void InitP2P(int count) {
 
 void InitDevices(bool init_p2p) {
   /*Init all available devices by default */
+  std::vector<int> devices;
+#ifdef PADDLE_WITH_CUDA
+  try {
+    int count = platform::GetCUDADeviceCount();
+    for (int i = 0; i < count; ++i) {
+      devices.push_back(i);
+    }
+  } catch (const std::exception &exp) {
+    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
+  }
+#else
+  LOG(WARNING)
+      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
+#endif
+  InitDevices(init_p2p, devices);
+}
 
+void InitDevices(bool init_p2p, const std::vector<int> devices) {
   std::vector<platform::Place> places;
-  places.emplace_back(platform::CPUPlace());
   int count = 0;
-
 #ifdef PADDLE_WITH_CUDA
   try {
     count = platform::GetCUDADeviceCount();
@@ -83,12 +101,17 @@ void InitDevices(bool init_p2p) {
         << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
 #endif
 
-  for (int i = 0; i < count; ++i) {
-    places.emplace_back(platform::CUDAPlace(i));
+  for (size_t i = 0; i < devices.size(); ++i) {
+    if (devices[i] >= count || devices[i] < 0) {
+      LOG(WARNING) << "Invalid devices id.";
+      continue;
+    }
+    places.emplace_back(platform::CUDAPlace(devices[i]));
   }
   if (init_p2p) {
-    InitP2P(count);
+    InitP2P(devices);
   }
+  places.emplace_back(platform::CPUPlace());
   platform::DeviceContextPool::Init(places);
 }
 

paddle/fluid/framework/init.h

Lines changed: 2 additions & 0 deletions
@@ -28,5 +28,7 @@ void InitGLOG(const std::string &prog_name);
 
 void InitDevices(bool init_p2p);
 
+void InitDevices(bool init_p2p, const std::vector<int> devices);
+
 }  // namespace framework
 }  // namespace paddle
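
Taken together, the framework changes add an InitDevices overload that registers only an explicit list of GPU ids (skipping out-of-range ids with a warning) instead of every visible device. Below is a minimal sketch of calling the new overload directly; the device ids 0 and 2 are illustrative, not part of the commit.

#include <vector>
#include "paddle/fluid/framework/init.h"

int main() {
  // Illustrative example: register only GPUs 0 and 2 (plus the CPU place)
  // and try to enable peer-to-peer access between them; invalid ids are
  // skipped with a warning by InitDevices.
  std::vector<int> devices = {0, 2};
  paddle::framework::InitDevices(/*init_p2p=*/true, devices);
  return 0;
}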

paddle/fluid/inference/io.cc

Lines changed: 15 additions & 3 deletions
@@ -16,17 +16,29 @@ limitations under the License. */
 
 #include <algorithm>
 #include <fstream>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/pybind/pybind.h"
 
+DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
+DEFINE_bool(init_p2p, false, "Whether to init p2p.");
+
 namespace paddle {
 namespace inference {
 
-// Temporarily add this function for exposing framework::InitDevices() when
-// linking the inference shared library.
-void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
+void Init(const std::vector<std::string> argv) {
+  framework::InitGflags(argv);
+  // init devices
+  std::vector<int> devices;
+  std::string token;
+  std::istringstream tokenStream(FLAGS_devices);
+  while (std::getline(tokenStream, token, ',')) {
+    devices.push_back(std::stoi(token));
+  }
+  framework::InitDevices(FLAGS_init_p2p, devices);
+}
 
 void ReadBinaryFile(const std::string& filename, std::string* contents) {
   std::ifstream fin(filename, std::ios::in | std::ios::binary);
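
On the inference side, Init no longer takes a bool: it forwards gflags-style arguments to framework::InitGflags and then reads the new --devices (comma-separated GPU ids) and --init_p2p flags to choose devices. A rough usage sketch follows; the flag values are illustrative.

#include <string>
#include <vector>
#include "paddle/fluid/inference/io.h"

int main() {
  // Illustrative invocation: initialize GPUs 0 and 1 with P2P enabled.
  // The strings are parsed as gflags, then --devices is split on commas.
  std::vector<std::string> argv = {"--devices=0,1", "--init_p2p=true"};
  paddle::inference::Init(argv);
  return 0;
}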

paddle/fluid/inference/io.h

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ limitations under the License. */
 namespace paddle {
 namespace inference {
 
-void Init(bool init_p2p);
+void Init(const std::vector<std::string> argv);
 
 void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
