Skip to content

Commit a186b53

Browse files
authored
add init_gflags interface (#5193)
* add init_gflags interface * refine code * follow comments
1 parent 6c8dce9 commit a186b53

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

paddle/pybind/pybind.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License. */
1414

1515
#include "paddle/pybind/protobuf.h"
1616

17+
#include <mutex> // for call_once
18+
#include "gflags/gflags.h"
1719
#include "paddle/framework/backward.h"
1820
#include "paddle/framework/executor.h"
1921
#include "paddle/framework/feed_fetch_method.h"
@@ -45,6 +47,24 @@ static size_t UniqueIntegerGenerator() {
4547
return generator.fetch_add(1);
4648
}
4749

50+
std::once_flag gflags_init_flag;
51+
52+
// TODO(qijun) move init gflags to init.cc
53+
void InitGflags(std::vector<std::string> &argv) {
54+
std::call_once(gflags_init_flag, [&]() {
55+
int argc = argv.size();
56+
char **arr = new char *[argv.size()];
57+
std::string line;
58+
for (size_t i = 0; i < argv.size(); i++) {
59+
arr[i] = &argv[i][0];
60+
line += argv[i];
61+
line += ' ';
62+
}
63+
google::ParseCommandLineFlags(&argc, &arr, true);
64+
VLOG(1) << "Init commandline: " << line;
65+
});
66+
}
67+
4868
bool IsCompileGPU() {
4969
#ifndef PADDLE_WITH_CUDA
5070
return false;
@@ -483,6 +503,7 @@ All parameter, weight, gradient are variables in Paddle.
483503
});
484504

485505
m.def("unique_integer", UniqueIntegerGenerator);
506+
m.def("init_gflags", InitGflags);
486507

487508
m.def("is_compile_gpu", IsCompileGPU);
488509
m.def("set_feed_variable", framework::SetFeedVariable);
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1+
import sys
2+
import core
13
__all__ = ['proto']
4+
argv = []
5+
if core.is_compile_gpu():
6+
argv = list(sys.argv) + [
7+
"--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"
8+
]
9+
else:
10+
argv = list(sys.argv) + ["--tryfromenv=use_pinned_memory"]
11+
core.init_gflags(argv)

0 commit comments

Comments
 (0)