Skip to content

Commit c4e3010

Browse files
committed
use template to do registry
1 parent d599de5 commit c4e3010

File tree

5 files changed

+26
-47
lines changed

5 files changed

+26
-47
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
nv_library(tensorrt_convert SRCS convert.cc DEPS dynload_cuda)
1+
nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda)
22
nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid)

paddle/fluid/inference/tensorrt/convert/convert_conv2d.h renamed to paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#pragma once
1615
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
1716

1817
namespace paddle {
1918
namespace inference {
2019
namespace tensorrt {
2120

22-
class Conv2dOpConverter : public OpConverter {
23-
public:
24-
Conv2dOpConverter() {}
25-
void Convert(const framework::OpDesc& op);
26-
};
21+
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
2722

2823
void Conv2dOpConverter::Convert(const framework::OpDesc& op) {
2924
LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias";
3025
}
3126

32-
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
33-
3427
} // namespace tensorrt
3528
} // namespace inference
3629
} // namespace paddle

paddle/fluid/inference/tensorrt/convert/convert.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
16-
#include "paddle/fluid/inference/tensorrt/convert/convert_conv2d.h"
17-
#include "paddle/fluid/inference/tensorrt/convert/convert_mul.h"
1816

1917
namespace paddle {
2018
namespace inference {
@@ -23,10 +21,8 @@ namespace tensorrt {
2321
void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) {
2422
for (auto op : block.AllOps()) {
2523
std::string type = op->Type();
26-
PADDLE_ENFORCE(GetOpConverter().count(type),
27-
"No converter registered for op: %s", type);
28-
auto op_converter = GetOpConverter()[type];
29-
op_converter->Convert(*op);
24+
OpConverter op_converter;
25+
op_converter.Convert(*op);
3026
}
3127
}
3228

paddle/fluid/inference/tensorrt/convert/convert.h

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,21 @@ namespace paddle {
2626
namespace inference {
2727
namespace tensorrt {
2828

29-
class ConverterBase {
29+
class OpConverter {
3030
public:
31-
ConverterBase() {}
31+
OpConverter() {}
32+
33+
void Convert(const framework::OpDesc& op) {
34+
std::string type = op.Type();
35+
OpConverter& op_converter = this->register_op_converter_[type];
36+
op_converter.Convert(op);
37+
}
38+
39+
template <typename T>
40+
static void Register(const std::string key) {
41+
register_op_converter_[key] = T();
42+
}
43+
static std::unordered_map<std::string, OpConverter> register_op_converter_;
3244

3345
// fluid inference scope
3446
framework::Scope* scope_;
@@ -37,30 +49,14 @@ class ConverterBase {
3749
std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
3850
};
3951

40-
class OpConverter : public ConverterBase {
41-
public:
42-
OpConverter() {}
43-
virtual ~OpConverter() {}
44-
45-
// convert fluid op to tensorrt layer
46-
virtual void Convert(const framework::OpDesc& op) = 0;
47-
};
48-
49-
static std::unordered_map<std::string, OpConverter*>& GetOpConverter() {
50-
static std::unordered_map<std::string, OpConverter*> register_op_converter;
51-
return register_op_converter;
52-
}
53-
54-
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
55-
class convert_class##Register { \
56-
public: \
57-
convert_class##Register() { \
58-
GetOpConverter()[#op_type] = new convert_class; \
59-
} \
60-
}; \
61-
convert_class##Register convert_class##reg;
52+
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
53+
class convert_class : public OpConverter { \
54+
public: \
55+
convert_class() { OpConverter::Register<convert_class>(#op_type); } \
56+
void Convert(const framework::OpDesc& op); \
57+
}
6258

63-
class TensorRTConverter : public ConverterBase {
59+
class TensorRTConverter {
6460
public:
6561
TensorRTConverter() {}
6662

paddle/fluid/inference/tensorrt/convert/convert_mul.h renamed to paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#pragma once
1615
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
1716

1817
namespace paddle {
1918
namespace inference {
2019
namespace tensorrt {
2120

22-
class MulOpConverter : public OpConverter {
23-
public:
24-
MulOpConverter() {}
25-
void Convert(const framework::OpDesc& op);
26-
};
27-
2821
REGISTER_TRT_OP_CONVETER(mul, MulOpConverter);
22+
2923
void MulOpConverter::Convert(const framework::OpDesc& op) {
3024
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
3125
}

0 commit comments

Comments
 (0)