Skip to content

Commit 4d49e61

Browse files
committed
fix comments
1 parent bcd67bd commit 4d49e61

File tree

6 files changed

+58
-9
lines changed

6 files changed

+58
-9
lines changed

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
3-
SRCS conv2d_op.cc fc_op.cc
3+
SRCS mul_op.cc conv2d_op.cc fc_op.cc
44
DEPS tensorrt_engine mul_op)
55

66
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
77
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
88

99
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
10+
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
11+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
1012
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
1113
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
1214
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc

paddle/fluid/inference/tensorrt/convert/fc_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,4 @@ class FcOpConverter : public OpConverter {
116116
} // namespace inference
117117
} // namespace paddle
118118

119-
REGISTER_TRT_OP_CONVERTER(mul, FcOpConverter);
120-
USE_OP(mul);
119+
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,4 @@ class MulOpConverter : public OpConverter {
5050
} // namespace paddle
5151

5252
USE_OP(mul);
53-
// TODO(xingzhaolong): change the name to mul then
54-
REGISTER_TRT_OP_CONVERTER(mul_temp, MulOpConverter);
53+
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);

paddle/fluid/inference/tensorrt/convert/test_fc_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ TEST(fc_op, test) {
2424
std::unordered_set<std::string> parameters({"mul-Y"});
2525
framework::Scope scope;
2626
TRTConvertValidation validator(10, parameters, scope, 1000);
27-
2827
validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1));
2928
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
3029
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
@@ -45,3 +44,4 @@ TEST(fc_op, test) {
4544
} // namespace tensorrt
4645
} // namespace inference
4746
} // namespace paddle
47+
USE_OP(mul);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
TEST(MulOpConverter, main) {
24+
framework::Scope scope;
25+
std::unordered_set<std::string> parameters;
26+
TRTConvertValidation validator(10, parameters, scope, 1000);
27+
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
28+
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
29+
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
30+
31+
// Prepare Op description
32+
framework::OpDesc desc;
33+
desc.SetType("mul");
34+
desc.SetInput("X", {"mul-X"});
35+
desc.SetInput("Y", {"mul-Y"});
36+
desc.SetOutput("Out", {"mul-Out"});
37+
38+
LOG(INFO) << "set OP";
39+
validator.SetOp(*desc.Proto());
40+
LOG(INFO) << "execute";
41+
42+
validator.Execute(1);
43+
}
44+
45+
} // namespace tensorrt
46+
} // namespace inference
47+
} // namespace paddle
48+
49+
USE_OP(mul);

paddle/fluid/operators/tensorrt_engine_op_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ TEST(TensorRTEngineOp, manual) {
6666
framework::BlockDesc block_desc(&program, block_);
6767
LOG(INFO) << "create fc op";
6868
auto* fc0 = block_desc.AppendOp();
69-
fc0->SetType("mul");
69+
fc0->SetType("fc");
7070
fc0->SetInput("X", std::vector<std::string>({"x"})); // 4 x 1 x 1
7171
fc0->SetInput("Y", std::vector<std::string>({"y"})); // 4 x 6
7272
fc0->SetOutput("Out", std::vector<std::string>({"z"})); // 6 x 1 x 1
7373

7474
LOG(INFO) << "create fc op";
7575
auto* fc1 = block_desc.AppendOp();
76-
fc1->SetType("mul");
76+
fc1->SetType("fc");
7777
fc1->SetInput("X", std::vector<std::string>({"z"}));
7878
fc1->SetInput("Y", std::vector<std::string>({"y0"})); // 6 x 8
7979
fc1->SetOutput("Out", std::vector<std::string>({"z0"})); // 8 x 1 x 1
@@ -208,4 +208,4 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
208208
} // namespace operators
209209
} // namespace paddle
210210

211-
USE_TRT_CONVERTER(mul)
211+
USE_TRT_CONVERTER(fc)

0 commit comments

Comments
 (0)