Skip to content

Commit eea3673

Browse files
committed
refine test_helper.h
test=develop
1 parent 2b791f1 commit eea3673

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,11 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
107107
}
108108

109109
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
110-
const std::string &dirname,
111-
const bool is_combined = true) {
110+
const std::string &dirname) {
112111
// Set fake_image_data
113112
PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
114113
std::vector<std::vector<int64_t>> feed_target_shapes =
115-
GetFeedTargetShapes(dirname, is_combined);
114+
GetFeedTargetShapes(dirname, true, "model", "params");
116115
int dim1 = feed_target_shapes[0][1];
117116
int dim2 = feed_target_shapes[0][2];
118117
int dim3 = feed_target_shapes[0][3];

paddle/fluid/inference/tests/test_helper.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1,
9393

9494
std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
9595
paddle::framework::Executor* executor, paddle::framework::Scope* scope,
96-
const std::string& dirname, const bool is_combined = false) {
96+
const std::string& dirname, const bool is_combined = false,
97+
const std::string& prog_filename = "__model_combined__",
98+
const std::string& param_filename = "__params_combined__") {
9799
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
98100
if (is_combined) {
99101
// All parameters are saved in a single file.
100102
// Hard-coding the file names of program and parameters in unittest.
101103
// The file names should be consistent with that used in Python API
102104
// `fluid.io.save_inference_model`.
103-
std::string prog_filename = "model";
104-
std::string param_filename = "params";
105105
inference_program =
106106
paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
107107
dirname + "/" + param_filename);
@@ -114,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
114114
}
115115

116116
std::vector<std::vector<int64_t>> GetFeedTargetShapes(
117-
const std::string& dirname, const bool is_combined = false) {
117+
const std::string& dirname, const bool is_combined = false,
118+
const std::string& prog_filename = "__model_combined__",
119+
const std::string& param_filename = "__params_combined__") {
118120
auto place = paddle::platform::CPUPlace();
119121
auto executor = paddle::framework::Executor(place);
120122
auto* scope = new paddle::framework::Scope();
121123

122-
auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
124+
auto inference_program = InitProgram(&executor, scope, dirname, is_combined,
125+
prog_filename, param_filename);
123126
auto& global_block = inference_program->Block(0);
124127

125128
const std::vector<std::string>& feed_target_names =

0 commit comments

Comments
 (0)