Commit 4bf6817

fix gpu load model
The parameters were loaded on CPUPlace regardless of the inference device, which keeps copying data between the CPU and GPU places. test=develop
1 parent 1722678 commit 4bf6817
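
In short, the load step now runs on the same place as the inference program. The snippet below is condensed from the diff hunks that follow; it is not a standalone program, just the core of the fix in one piece.

platform::Place place;  // boost::variant over Paddle's place types
if (argument->use_gpu()) {
  place = platform::CUDAPlace(argument->gpu_device_id());  // load on the GPU
} else {
  place = platform::CPUPlace();                            // CPU inference
}
framework::Executor exe(place);  // the load program now runs on `place`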

File tree

4 files changed, +26 −11 lines


paddle/fluid/inference/analysis/argument.h

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ struct Argument {
                       std::vector<std::string>);
 
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
+  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
   DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller,
                       std::function<bool(const framework::ir::Node*)>);
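
The exact expansion of DECL_ARGUMENT_FIELD is defined elsewhere in argument.h and is not part of this diff. The sketch below is a guess at its shape, inferred only from the accessors this commit actually calls (gpu_device_id(), SetGPUDeviceId(), gpu_device_id_valid()); the real macro may differ in storage and naming details.

#include <memory>

// Hypothetical hand-expansion of
// DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int); illustration only.
class ArgumentSketch {
 public:
  int gpu_device_id() const { return *gpu_device_id_; }
  void SetGPUDeviceId(int id) { gpu_device_id_.reset(new int(id)); }
  bool gpu_device_id_valid() const { return gpu_device_id_ != nullptr; }

 private:
  std::unique_ptr<int> gpu_device_id_;  // null until the field is set
};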

paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc

Lines changed: 18 additions & 6 deletions
@@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
   if (!argument->scope_valid()) {
     argument->SetScope(new framework::Scope);
   }
+  PADDLE_ENFORCE(argument->use_gpu_valid());
+
+  // The load program should run on the same device as the inference program,
+  // so that the parameters will be on the same device; otherwise they will
+  // keep copying between different devices.
+  platform::Place place;
+  if (argument->use_gpu()) {
+    PADDLE_ENFORCE(argument->gpu_device_id_valid());
+    place = platform::CUDAPlace(argument->gpu_device_id());
+  } else {
+    place = platform::CPUPlace();
+  }
 
   if (argument->model_dir_valid()) {
-    auto program = LoadModel(argument->model_dir(), argument->scope_ptr());
+    auto program =
+        LoadModel(argument->model_dir(), argument->scope_ptr(), place);
     argument->SetMainProgram(program.release());
   } else if (argument->model_program_path_valid() &&
              argument->model_params_path_valid()) {
     auto program =
         LoadModel(argument->model_program_path(), argument->model_params_path(),
-                  argument->scope_ptr());
+                  argument->scope_ptr(), place);
     argument->SetMainProgram(program.release());
   } else {
     PADDLE_THROW(

@@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
-    const std::string &path, framework::Scope *scope) {
-  platform::CPUPlace place;
+    const std::string &path, framework::Scope *scope,
+    const platform::Place &place) {
   framework::Executor exe(place);
   return Load(&exe, scope, path);
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
     const std::string &program_path, const std::string &params_path,
-    framework::Scope *scope) {
+    framework::Scope *scope, const platform::Place &place) {
   framework::Executor exe(place);
   return Load(&exe, scope, program_path, params_path);
 }
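
The design point: framework::Executor runs the load program's ops on the place it is constructed with, so the parameters materialize on that device. Below is a standalone toy version of the selection rule, with stand-in structs instead of Paddle's real place types (the real Place is a boost::variant, see the header change next); illustration only.

#include <cassert>
#include <variant>

// Toy stand-ins for paddle::platform's place types.
struct CPUPlace {};
struct CUDAPlace { int device_id; };
using Place = std::variant<CPUPlace, CUDAPlace>;

// Mirrors the rule added to RunImpl: run the load program on the same
// device as the inference program.
Place ChoosePlace(bool use_gpu, int gpu_device_id) {
  if (use_gpu) {
    assert(gpu_device_id >= 0);       // stands in for gpu_device_id_valid()
    return CUDAPlace{gpu_device_id};  // parameters land on this GPU
  }
  return CPUPlace{};                  // CPU inference: load on the CPU
}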

paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h

Lines changed: 5 additions & 3 deletions
@@ -32,11 +32,13 @@ class IrGraphBuildPass : public AnalysisPass {
   std::string repr() const override;
 
  private:
-  std::unique_ptr<framework::ProgramDesc> LoadModel(const std::string &path,
-                                                    framework::Scope *scope);
+  std::unique_ptr<framework::ProgramDesc> LoadModel(
+      const std::string &path, framework::Scope *scope,
+      const boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> &place);
   std::unique_ptr<framework::ProgramDesc> LoadModel(
       const std::string &program_path, const std::string &params_path,
-      framework::Scope *scope);
+      framework::Scope *scope,
+      const boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> &place);
 
   std::string model_binary_str_;
 };
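
Note the header spells out boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> where the .cc file writes platform::Place, presumably to avoid pulling the platform headers into this header. For readers unfamiliar with boost::variant, here is a self-contained example of its dispatch pattern, again with toy place structs standing in for Paddle's:

#include <boost/variant.hpp>
#include <iostream>

// Toy place structs; illustration only.
struct CPUPlace {};
struct CUDAPlace { int device_id = 0; };
struct CUDAPinnedPlace {};
using Place = boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace>;

// A static_visitor dispatches on whichever alternative the variant holds.
struct PlaceName : boost::static_visitor<const char *> {
  const char *operator()(const CPUPlace &) const { return "CPUPlace"; }
  const char *operator()(const CUDAPlace &) const { return "CUDAPlace"; }
  const char *operator()(const CUDAPinnedPlace &) const {
    return "CUDAPinnedPlace";
  }
};

int main() {
  Place place = CUDAPlace{0};
  std::cout << boost::apply_visitor(PlaceName(), place) << std::endl;
  return 0;
}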

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 2 deletions
@@ -285,6 +285,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   status_program_optimized_ = true;
 
   argument_.SetUseGPU(config_.use_gpu);
+  argument_.SetGPUDeviceId(config_.device);
   // Analyze inference_program
   if (!config_.model_dir.empty()) {
     argument_.SetModelDir(config_.model_dir);

@@ -491,8 +492,7 @@ bool AnalysisPredictor::LoadParameters() {
   }
 
   // Use NaiveExecutor to Load parameters.
-  platform::CPUPlace place;
-  framework::NaiveExecutor e(place);
+  framework::NaiveExecutor e(place_);
   e.Prepare(scope_.get(), *load_program, 0, false);
   e.Run();
   VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";
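
From the caller's side, the commit only forwards two config fields (use_gpu, device) that the diff shows already exist. A sketch of the user-facing setup, assuming the Paddle 1.x inference API of this era (contrib::AnalysisConfig and CreatePaddlePredictor; verify against the actual headers):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::contrib::AnalysisConfig config;
  config.model_dir = "/path/to/model";  // hypothetical model directory
  config.use_gpu = true;                // forwarded via argument_.SetUseGPU
  config.device = 0;                    // forwarded via argument_.SetGPUDeviceId
  // With this commit, parameters load directly onto GPU 0 instead of being
  // loaded on the CPU and copied to the GPU at run time.
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::contrib::AnalysisConfig>(config);
  return 0;
}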
