Skip to content

Commit 0528067

Browse files
authored
Merge pull request #14500 from NHZlX/refine_trt
Fix gpu load model and demo_ci on trt
2 parents a8d3aaa + a4dc1d4 commit 0528067

File tree

11 files changed

+53
-28
lines changed

11 files changed

+53
-28
lines changed

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@ set(analysis_deps # analysis_deps can be extended across the project
77
add_subdirectory(ir_passes)
88
add_subdirectory(passes)
99

10-
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass ${INFER_IR_PASSES})
10+
cc_library(analysis_helper SRCS helper.cc DEPS framework_proto proto_desc graph paddle_fluid_api)
11+
12+
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass ${INFER_IR_PASSES} analysis_helper)
1113

1214
cc_library(argument SRCS argument.cc DEPS scope proto_desc)
1315
cc_library(analysis_pass SRCS analysis_pass.cc DEPS proto_desc)
1416

1517
cc_library(analysis SRCS
1618
analyzer.cc
17-
helper.cc
1819
analysis_pass
19-
DEPS ${analysis_deps}
20+
DEPS ${analysis_deps} analysis_helper
2021
)
2122

2223
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
3030
Argument argument;
3131
argument.SetModelDir(FLAGS_inference_model_dir);
3232
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
33+
argument.SetUseGPU(false);
3334

3435
Analyzer analyser;
3536
analyser.Run(&argument);
@@ -41,6 +42,7 @@ TEST(Analyzer, analysis_with_tensorrt) {
4142
argument.SetTensorRtWorkspaceSize(1 << 20);
4243
argument.SetModelDir(FLAGS_inference_model_dir);
4344
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
45+
argument.SetUseGPU(false);
4446

4547
Analyzer analyser;
4648
analyser.Run(&argument);

paddle/fluid/inference/analysis/argument.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ struct Argument {
116116
std::vector<std::string>);
117117

118118
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
119+
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
119120
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
120121
DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller,
121122
std::function<bool(const framework::ir::Node*)>);

paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ set(analysis_deps ${analysis_deps}
44
subgraph_detector tensorrt_subgraph_pass
55
CACHE INTERNAL "")
66

7+
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
8+
file(APPEND ${pass_file} "USE_PASS(tensorrt_subgraph_pass);\n")
79
set(INFER_IR_PASSES ${INFER_IR_PASSES} tensorrt_subgraph_pass CACHE INTERNAL "")

paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
3030
if (!argument->scope_valid()) {
3131
argument->SetScope(new framework::Scope);
3232
}
33+
PADDLE_ENFORCE(argument->use_gpu_valid());
34+
35+
// The load program should run on the same device with the inference program,
36+
// so that the parameters will be on the same device, or they will keep copying
37+
// between different devices.
38+
platform::Place place;
39+
if (argument->use_gpu()) {
40+
PADDLE_ENFORCE(argument->gpu_device_id_valid());
41+
place = platform::CUDAPlace(argument->gpu_device_id());
42+
} else {
43+
place = platform::CPUPlace();
44+
}
3345

3446
if (argument->model_dir_valid()) {
35-
auto program = LoadModel(argument->model_dir(), argument->scope_ptr());
47+
auto program =
48+
LoadModel(argument->model_dir(), argument->scope_ptr(), place);
3649
argument->SetMainProgram(program.release());
3750
} else if (argument->model_program_path_valid() &&
3851
argument->model_params_path_valid()) {
3952
auto program =
4053
LoadModel(argument->model_program_path(), argument->model_params_path(),
41-
argument->scope_ptr());
54+
argument->scope_ptr(), place);
4255
argument->SetMainProgram(program.release());
4356
} else {
4457
PADDLE_THROW(
@@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
5265
}
5366

5467
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
55-
const std::string &path, framework::Scope *scope) {
56-
platform::CPUPlace place;
68+
const std::string &path, framework::Scope *scope,
69+
const platform::Place &place) {
5770
framework::Executor exe(place);
5871
return Load(&exe, scope, path);
5972
}
6073

6174
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
6275
const std::string &program_path, const std::string &params_path,
63-
framework::Scope *scope) {
64-
platform::CPUPlace place;
76+
framework::Scope *scope, const platform::Place &place) {
6577
framework::Executor exe(place);
6678
return Load(&exe, scope, program_path, params_path);
6779
}

paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include "paddle/fluid/framework/scope.h"
1919
#include "paddle/fluid/inference/analysis/analysis_pass.h"
20+
#include "paddle/fluid/platform/place.h"
2021

2122
namespace paddle {
2223
namespace inference {
@@ -32,11 +33,12 @@ class IrGraphBuildPass : public AnalysisPass {
3233
std::string repr() const override;
3334

3435
private:
35-
std::unique_ptr<framework::ProgramDesc> LoadModel(const std::string &path,
36-
framework::Scope *scope);
36+
std::unique_ptr<framework::ProgramDesc> LoadModel(
37+
const std::string &path, framework::Scope *scope,
38+
const platform::Place &place);
3739
std::unique_ptr<framework::ProgramDesc> LoadModel(
3840
const std::string &program_path, const std::string &params_path,
39-
framework::Scope *scope);
41+
framework::Scope *scope, const platform::Place &place);
4042

4143
std::string model_binary_str_;
4244
};

paddle/fluid/inference/api/CMakeLists.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,10 @@ endif()
2727
cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope)
2828
cc_library(analysis_config SRCS analysis_config.cc DEPS lod_tensor paddle_pass_builder)
2929
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
30-
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config analysis_config paddle_pass_builder)
31-
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder)
32-
cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api)
33-
cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api)
34-
30+
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager)
31+
cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS scope lod_tensor enforce)
32+
cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc)
33+
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config analysis_config paddle_pass_builder DEPS zero_copy_tensor)
3534

3635
cc_test(test_paddle_inference_api
3736
SRCS api_tester.cc

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
285285
status_program_optimized_ = true;
286286

287287
argument_.SetUseGPU(config_.use_gpu);
288+
argument_.SetGPUDeviceId(config_.device);
288289
// Analyze inference_program
289290
if (!config_.model_dir.empty()) {
290291
argument_.SetModelDir(config_.model_dir);
@@ -491,8 +492,7 @@ bool AnalysisPredictor::LoadParameters() {
491492
}
492493

493494
// Use NaiveExecutor to Load parameters.
494-
platform::CPUPlace place;
495-
framework::NaiveExecutor e(place);
495+
framework::NaiveExecutor e(place_);
496496
e.Prepare(scope_.get(), *load_program, 0, false);
497497
e.Run();
498498
VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";

paddle/fluid/inference/api/paddle_pass_builder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,12 @@ class CpuPassStrategy : public PassStrategy {
116116
class GpuPassStrategy : public PassStrategy {
117117
public:
118118
GpuPassStrategy() : PassStrategy({}) {
119+
// TODO(NHZlX) Problem with Data synchronization between GPU and CPU
120+
// When running in GPU mode, the parameters are all on GPU. But the
121+
// operations of "conv_bn_fuse_pass" are on CPU.
119122
passes_.assign({
120-
"infer_clean_graph_pass", "conv_bn_fuse_pass",
123+
"infer_clean_graph_pass",
124+
// "infer_clean_graph_pass", "conv_bn_fuse_pass",
121125
});
122126
}
123127

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor)
22

3+
if(WITH_GPU AND TENSORRT_FOUND)
4+
set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor)
5+
endif()
6+
37
function(download_model install_dir model_name)
48
if (NOT EXISTS ${install_dir})
59
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name})
@@ -75,11 +79,11 @@ endif()
7579
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
7680

7781
# resnet50
78-
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
82+
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
7983
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
8084

8185
# mobilenet with depthwise_conv op
82-
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet
86+
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet
8387
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
8488

8589
# anakin
@@ -89,15 +93,15 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
8993
set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1")
9094
inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin")
9195
inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt")
92-
cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc
93-
ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin
96+
cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc
97+
ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin
9498
--datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt
9599
DEPS inference_anakin_api_shared SERIAL)
96100
# anakin mobilenet
97101
if(WITH_GPU)
98102
set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet")
99103
inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin")
100-
cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc
104+
cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc
101105
ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin
102106
DEPS inference_anakin_api_shared dynload_cuda SERIAL)
103107
endif()
@@ -109,6 +113,6 @@ if(WITH_GPU AND TENSORRT_FOUND)
109113
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
110114
endif()
111115
inference_analysis_test(test_trt_models SRCS trt_models_tester.cc
112-
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor
116+
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
113117
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL)
114118
endif()

0 commit comments

Comments
 (0)