Skip to content

Commit 90ffe74

Browse files
authored
Merge pull request #15546 from NHZlX/fix_trt_utest_random_failed
fix trt models utest failed.
2 parents c744922 + 95b98f2 commit 90ffe74

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,6 @@ DECLARE_int32(paddle_num_threads);
5656
namespace paddle {
5757
namespace inference {
5858

59-
float Random(float low, float high) {
60-
static std::random_device rd;
61-
static std::mt19937 mt(rd());
62-
std::uniform_real_distribution<double> dist(low, high);
63-
return dist(mt);
64-
}
65-
6659
void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
6760
const auto *analysis_config =
6861
reinterpret_cast<const AnalysisConfig *>(config);
@@ -146,7 +139,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
146139
const std::string &dirname, bool is_combined = true,
147140
std::string model_filename = "model",
148141
std::string params_filename = "params",
149-
const std::vector<std::string> *feed_names = nullptr) {
142+
const std::vector<std::string> *feed_names = nullptr,
143+
const int continuous_inuput_index = 0) {
150144
// Set fake_image_data
151145
PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
152146
std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
@@ -183,7 +177,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
183177
float *input_data = static_cast<float *>(input.data.data());
184178
// fill input data, for profile easily, do not use random data here.
185179
for (size_t j = 0; j < len; ++j) {
186-
*(input_data + j) = Random(0.0, 1.0) / 10.;
180+
*(input_data + j) =
181+
static_cast<float>((j + continuous_inuput_index) % len) / len;
187182
}
188183
}
189184
(*inputs).emplace_back(input_slots);

paddle/fluid/inference/tests/api/trt_models_tester.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,10 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
119119
std::vector<std::vector<PaddleTensor>> inputs_all;
120120
if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
121121
SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
122-
FLAGS_param_filename);
122+
FLAGS_param_filename, nullptr, i);
123123
} else {
124-
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
124+
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "", nullptr,
125+
i);
125126
}
126127
CompareNativeAndAnalysis(native_pred.get(), analysis_pred.get(),
127128
inputs_all);

0 commit comments

Comments
 (0)