
Commit 7382f98 (1 parent: 4e377f8)

1. Set the unit-test batch size to > 1. 2. Re-add the mul op (its unit test will be added later).

3 files changed, +56 -2 lines

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * MulOp, IMatrixMultiplyLayer in TRT. This layer doesn't have weights.
 */
class MulOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
    // Neither input1 nor input2 needs to be transposed.
    auto* layer = TRT_ENGINE_ADD_LAYER(
        engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
        *const_cast<nvinfer1::ITensor*>(input2), false);

    auto output_name = op_desc.Output("Out")[0];
    engine_->SetITensor(output_name, layer->getOutput(0));
    if (test_mode) {  // The test framework cannot determine which tensor is
                      // the output, so declare it here.
      engine_->DeclareOutput(output_name);
    }
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

USE_OP(mul);
// TODO(xingzhaolong): change the registered name back to mul later
REGISTER_TRT_OP_CONVERTER(mul_temp, MulOpConverter);
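
The commit message notes that a unit test for this converter will be added later. As a reference, the sketch below shows what such a test could look like, following the TRTConvertValidation pattern used by the two tests modified in this commit. It is a sketch only: the file name (test_mul_op.cc), the tensor shapes, the SetOp/Execute helper calls, and registering the op as plain mul (per the TODO above) are assumptions, not part of this commit.

// Hypothetical paddle/fluid/inference/tensorrt/convert/test_mul_op.cc (not in this commit).
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(MulOpConverter, main) {
  framework::Scope scope;
  std::unordered_set<std::string> parameters;
  // Max batch size 10 and workspace size 1000 mirror the other tests in this commit.
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
  validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));

  // Build the fluid op description. Uses "mul", which assumes the converter
  // registration above has been renamed back from mul_temp per the TODO.
  framework::OpDesc desc;
  desc.SetType("mul");
  desc.SetInput("X", {"mul-X"});
  desc.SetInput("Y", {"mul-Y"});
  desc.SetOutput("Out", {"mul-Out"});

  validator.SetOp(*desc.Proto());

  // Run with a batch size > 1 (must not exceed the validator's max batch).
  validator.Execute(2);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

USE_OP(mul);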

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ namespace tensorrt {
 TEST(ReluOpConverter, main) {
   framework::Scope scope;
   std::unordered_set<std::string> parameters;
-  TRTConvertValidation validator(1, parameters, scope, 1000);
+  TRTConvertValidation validator(10, parameters, scope, 1000);
   validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6));
   validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6));

paddle/fluid/inference/tensorrt/convert/test_fc_op.cc

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ namespace tensorrt {
 TEST(fc_op, test) {
   std::unordered_set<std::string> parameters({"mul-Y"});
   framework::Scope scope;
-  TRTConvertValidation validator(1, parameters, scope, 1000);
+  TRTConvertValidation validator(10, parameters, scope, 1000);

   validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1));
   validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
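
In both tests above, the argument raised from 1 to 10 is the validator's maximum batch size, which is the "set ut batch > 1" part of the commit message. Below is a small annotated sketch of how the call reads under that assumption; the parameter names are illustrative and only the 1 -> 10 change comes from this diff.

// Assumed meaning of the positional arguments (names are illustrative):
//   TRTConvertValidation(max_batch_size, parameters, scope, trt_workspace_size)
TRTConvertValidation validator(10 /* max batch size, previously 1 */, parameters,
                               scope, 1000 /* TensorRT workspace size */);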
