Skip to content

Commit c14d2c6

Browse files
authored
Fix unittest (#29412) (#29437)
* fix tensorrt unittest precision error * fix unittest precision error. test_trt_subgraph_pass && test_trt_dynamic_shape_transformer_prune
1 parent b776434 commit c14d2c6

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
126126
run(config, &out_data);
127127

128128
for (size_t i = 0; i < out_data.size(); i++) {
129-
EXPECT_NEAR(result[i], out_data[i], 1e-4);
129+
EXPECT_NEAR(result[i], out_data[i], 2e-3);
130130
}
131131
}
132132

python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ def test_check_output(self):
308308
use_gpu = True
309309
if os.path.exists(self.path + "_opt_cache"):
310310
shutil.rmtree(self.path + "_opt_cache")
311-
self.check_output_with_option(use_gpu)
311+
if self.trt_parameters.precision == AnalysisConfig.Precision.Float32:
312+
self.check_output_with_option(use_gpu)
313+
else:
314+
self.check_output_with_option(use_gpu, 1e-3)
312315
self.assertTrue(
313316
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
314317

@@ -567,7 +570,7 @@ def test_check_output(self):
567570
use_gpu = True
568571
if os.path.exists(self.path + "_opt_cache"):
569572
shutil.rmtree(self.path + "_opt_cache")
570-
self.check_output_with_option(use_gpu)
573+
self.check_output_with_option(use_gpu, 1e-3)
571574
self.assertTrue(
572575
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
573576

0 commit comments

Comments
 (0)