Skip to content

Commit 64a08f8

Browse files
committed
increase the test batch
1 parent c13efe0 commit 64a08f8

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace tensorrt {
2323
TEST(elementwise_op, add_weight_test) {
2424
std::unordered_set<std::string> parameters({"elementwise_add-Y"});
2525
framework::Scope scope;
26-
TRTConvertValidation validator(1, parameters, scope, 1 << 15);
26+
TRTConvertValidation validator(10, parameters, scope, 1 << 15);
2727
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
2828
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
2929
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
@@ -41,13 +41,13 @@ TEST(elementwise_op, add_weight_test) {
4141

4242
validator.SetOp(*desc.Proto());
4343

44-
validator.Execute(1);
44+
validator.Execute(8);
4545
}
4646

4747
TEST(elementwise_op, add_tensor_test) {
4848
std::unordered_set<std::string> parameters;
4949
framework::Scope scope;
50-
TRTConvertValidation validator(2, parameters, scope, 1 << 15);
50+
TRTConvertValidation validator(8, parameters, scope, 1 << 15);
5151
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
5252
validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
5353
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
@@ -64,7 +64,7 @@ TEST(elementwise_op, add_tensor_test) {
6464

6565
validator.SetOp(*desc.Proto());
6666

67-
validator.Execute(1);
67+
validator.Execute(8);
6868
}
6969

7070
} // namespace tensorrt

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class TRTConvertValidation {
149149
cudaStreamSynchronize(*engine_->stream());
150150

151151
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
152-
const size_t output_space_size = 2000;
152+
const size_t output_space_size = 3000;
153153
for (const auto& output : op_desc_->OutputArgumentNames()) {
154154
std::vector<float> fluid_out;
155155
std::vector<float> trt_out(output_space_size);

0 commit comments

Comments
 (0)