Skip to content

Commit 6f6f330

Browse files
committed
update the register method
1 parent 326221a commit 6f6f330

File tree

8 files changed

+115
-118
lines changed

8 files changed

+115
-118
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
22
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
3-
cc_library(tensorrt DEPS tensorrt_convert)
43
add_subdirectory(convert)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda)
2-
nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid)
1+
file(GLOB TENSORRT_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
2+
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc ${TENSORRT_OPS} DEPS ${FLUID_CORE_MODULES})

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,30 @@ Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
55
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
99
Unless required by applicable law or agreed to in writing, software
1010
distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
1616

1717
namespace paddle {
1818
namespace inference {
1919
namespace tensorrt {
2020

21-
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
21+
class Conv2dOpConverter : public OpConverter {
22+
public:
23+
Conv2dOpConverter() {}
24+
void operator()(const framework::OpDesc& op) override {
25+
LOG(INFO)
26+
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
27+
}
28+
};
2229

23-
void Conv2dOpConverter::Convert(const framework::OpDesc& op) {
24-
LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias";
25-
}
30+
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
2631

2732
} // namespace tensorrt
2833
} // namespace inference

paddle/fluid/inference/tensorrt/convert/convert.cc

Lines changed: 0 additions & 31 deletions
This file was deleted.

paddle/fluid/inference/tensorrt/convert/convert.h

Lines changed: 0 additions & 69 deletions
This file was deleted.

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,29 @@ Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
55
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
99
Unless required by applicable law or agreed to in writing, software
1010
distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
1616

1717
namespace paddle {
1818
namespace inference {
1919
namespace tensorrt {
2020

21-
REGISTER_TRT_OP_CONVETER(mul, MulOpConverter);
21+
class MulOpConverter : public OpConverter {
22+
public:
23+
MulOpConverter() {}
24+
void operator()(const framework::OpDesc& op) override {
25+
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
26+
}
27+
};
2228

23-
void MulOpConverter::Convert(const framework::OpDesc& op) {
24-
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
25-
}
29+
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
2630

2731
} // namespace tensorrt
2832
} // namespace inference
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <unordered_map>
19+
#include "paddle/fluid/framework/block_desc.h"
20+
#include "paddle/fluid/framework/scope.h"
21+
#include "paddle/fluid/inference/tensorrt/engine.h"
22+
23+
namespace paddle {
24+
namespace inference {
25+
namespace tensorrt {
26+
27+
/*
28+
* Convert Op from Fluid to TensorRT Engine.
29+
*/
30+
class OpConverter {
31+
public:
32+
OpConverter() {}
33+
34+
virtual void operator()(const framework::OpDesc& op) {}
35+
void Execute(const framework::OpDesc& op) {
36+
std::string type = op.Type();
37+
auto it = converters_.find(type);
38+
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
39+
type);
40+
(*it->second)(op);
41+
}
42+
43+
static OpConverter& Global() {
44+
static auto* x = new OpConverter;
45+
return *x;
46+
}
47+
48+
template <typename T>
49+
void Register(const std::string& key) {
50+
converters_[key] = new T;
51+
}
52+
53+
virtual ~OpConverter() {}
54+
55+
private:
56+
// registered op converter map, whose key is the fluid op type, and value is
57+
// the pointer position of corresponding OpConverter class.
58+
std::unordered_map<std::string, OpConverter*> converters_;
59+
60+
// fluid inference scope
61+
framework::Scope* scope_;
62+
// tensorrt input/output tensor map, whose key is the fluid variable name,
63+
// and value is the pointer position of tensorrt tensor
64+
std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
65+
};
66+
67+
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
68+
struct trt_##op_type__##_converter { \
69+
trt_##op_type__##_converter() { \
70+
OpConverter::Global().Register<Converter__>(#op_type__); \
71+
} \
72+
}; \
73+
trt_##op_type__##_converter trt_##op_type__##_converter__;
74+
75+
class BlockConverter {
76+
public:
77+
BlockConverter() {}
78+
79+
// convert fluid block to tensorrt network
80+
void ConvertBlock(const framework::BlockDesc& block) {
81+
for (auto op : block.AllOps()) {
82+
OpConverter::Global().Execute(*op);
83+
}
84+
}
85+
};
86+
87+
} // namespace tensorrt
88+
} // namespace inference
89+
} // namespace paddle

paddle/fluid/inference/tensorrt/convert/test_convert.cc renamed to paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ limitations under the License. */
1414

1515
#include <gtest/gtest.h>
1616
#include "paddle/fluid/framework/program_desc.h"
17-
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
1818

1919
namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(tensorrt, ConvertBlock) {
23+
TEST(BlockConverter, ConvertBlock) {
2424
framework::ProgramDesc prog;
2525
auto* block = prog.MutableBlock(0);
2626
auto* mul_op = block->AppendOp();
2727
mul_op->SetType("mul");
2828
auto* conv2d_op = block->AppendOp();
2929
conv2d_op->SetType("conv2d");
3030

31-
TensorRTConverter converter;
31+
BlockConverter converter;
3232
converter.ConvertBlock(*block);
3333
}
3434

0 commit comments

Comments
 (0)