Skip to content

Commit 2d7134b

Browse files
committed
add initial code for plugin
1 parent 0b38822 commit 2d7134b

14 files changed

+653
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
22
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
33
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
4+
add_subdirectory(plugin)
45
add_subdirectory(convert)

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
nv_library(tensorrt_converter
33
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
44
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
5-
DEPS tensorrt_engine operator scope framework_proto op_registry)
5+
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
66

77
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
88
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)

paddle/fluid/inference/tensorrt/convert/concat_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace inference {
1919
namespace tensorrt {
2020

2121
/*
22-
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
22+
* ConcatOp
2323
*/
2424
class ConcatOpConverter : public OpConverter {
2525
public:
Binary file not shown.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
nv_library(tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc
2+
trt_plugin.cc split_op_plugin.cu DEPS enforce)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"
16+
17+
namespace paddle {
18+
namespace inference {
19+
namespace tensorrt {
20+
21+
// Deserialization entry point, invoked by TensorRT when rebuilding an
// engine from a serialized plan. The serialized blob starts with an
// op-name header (see ExtractOpName); the matching registered
// deserialize function rebuilds the plugin from the full blob.
// Returns nullptr when no plugin is registered for the encoded op name.
// The factory retains ownership of the created plugin (owned_plugins_).
PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
                                                    const void* serial_data,
                                                    size_t serial_length) {
  size_t parsed_byte = 0;
  // parsed_byte reports how many header bytes the op name consumed; the
  // deserialize function still receives the whole buffer, header included.
  std::string encoded_op_name =
      ExtractOpName(serial_data, serial_length, &parsed_byte);

  // Single registry lookup instead of IsPlugin() followed by operator[].
  auto it = plugin_registry_.find(encoded_op_name);
  if (it == plugin_registry_.end()) {
    return nullptr;
  }

  auto plugin_ptr = it->second.first(serial_data, serial_length);
  owned_plugins_.emplace_back(plugin_ptr);

  return plugin_ptr;
}
38+
39+
// Constructs a fresh plugin for op_name during network building.
// Returns nullptr when op_name has no registered construct function.
// The factory retains ownership of the created plugin.
PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(
    const std::string& op_name) {
  // Single registry lookup instead of IsPlugin() followed by operator[].
  auto it = plugin_registry_.find(op_name);
  if (it == plugin_registry_.end()) return nullptr;

  auto plugin_ptr = it->second.second();
  owned_plugins_.emplace_back(plugin_ptr);

  return plugin_ptr;
}
48+
49+
bool PluginFactoryTensorRT::RegisterPlugin(
50+
const std::string& op_name, PluginDeserializeFunc deserialize_func,
51+
PluginConstructFunc construct_func) {
52+
if (IsPlugin(op_name)) return false;
53+
54+
auto ret = plugin_registry_.emplace(
55+
op_name, std::make_pair(deserialize_func, construct_func));
56+
57+
return ret.second;
58+
}
59+
60+
void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }
61+
62+
} // namespace tensorrt
63+
} // namespace inference
64+
} // namespace paddle
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <unordered_map>
19+
20+
#include "NvInfer.h"
21+
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
22+
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
23+
#include "paddle/fluid/platform/enforce.h"
24+
25+
namespace paddle {
26+
namespace inference {
27+
namespace tensorrt {
28+
29+
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
30+
public:
31+
static PluginFactoryTensorRT* GetInstance() {
32+
static PluginFactoryTensorRT* factory_instance =
33+
new PluginFactoryTensorRT();
34+
return factory_instance;
35+
}
36+
37+
// Deserialization method
38+
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
39+
size_t serial_length) override;
40+
41+
// Plugin construction, PluginFactoryTensorRT owns the plugin.
42+
PluginTensorRT* CreatePlugin(const std::string& op_name);
43+
44+
bool RegisterPlugin(const std::string& op_name,
45+
PluginDeserializeFunc deserialize_func,
46+
PluginConstructFunc construct_func);
47+
48+
bool IsPlugin(const std::string& op_name) {
49+
return plugin_registry_.find(op_name) != plugin_registry_.end();
50+
}
51+
52+
size_t CountOwnedPlugins() { return owned_plugins_.size(); }
53+
54+
void DestroyPlugins();
55+
56+
protected:
57+
std::unordered_map<std::string,
58+
std::pair<PluginDeserializeFunc, PluginConstructFunc>>
59+
plugin_registry_;
60+
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
61+
};
62+
63+
class TrtPluginRegistrar {
64+
public:
65+
TrtPluginRegistrar(const std::string& name,
66+
PluginDeserializeFunc deserialize_func,
67+
PluginConstructFunc construct_func) {
68+
auto factory = PluginFactoryTensorRT::GetInstance();
69+
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
70+
// construct_func), "Falied to register plugin [%s]", name);
71+
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
72+
// construct_func));
73+
factory->RegisterPlugin(name, deserialize_func, construct_func);
74+
}
75+
};
76+
77+
#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
78+
REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
79+
construct_func)
80+
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
81+
construct_func) \
82+
REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
83+
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
84+
static ::paddle::inference::tensorrt::TrtPluginRegistrar \
85+
trt_plugin_registrar##ctr __attribute__((unused)) = \
86+
::paddle::inference::tensorrt::TrtPluginRegistrar( \
87+
name, deserialize_func, construct_func)
88+
89+
} // namespace tensorrt
90+
} // namespace inference
91+
} // namespace paddle
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
16+
#include <cassert>
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace tensorrt {
21+
22+
// Reads the op name that serialization prepends to a plugin blob.
//
// Wire layout: [size_t op_name_char_count][op_name chars][plugin payload].
// On return, *incremental holds the number of header bytes consumed
// (sizeof(size_t) + op_name_char_count) so callers can skip past the name.
std::string ExtractOpName(const void* serial_data, size_t serial_length,
                          size_t* incremental) {
  // The buffer must at least hold the length prefix before it is read;
  // the original only validated the length AFTER dereferencing.
  assert(serial_length >= sizeof(size_t));

  // NOTE(review): assumes serial_data is suitably aligned for size_t —
  // holds for TensorRT-provided buffers; use memcpy if that ever changes.
  size_t op_name_char_count = *static_cast<const size_t*>(serial_data);
  *incremental = sizeof(size_t) + op_name_char_count;

  assert(serial_length >= *incremental);

  const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t);
  std::string op_name(buffer, op_name_char_count);

  return op_name;
}
34+
35+
} // namespace tensorrt
36+
} // namespace inference
37+
} // namespace paddle
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

#include <functional>
#include <string>  // std::string was used below without being included

#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Rebuilds a plugin from a serialized blob (buffer pointer + byte length).
typedef std::function<PluginTensorRT*(const void*, size_t)>
    PluginDeserializeFunc;
// Constructs a fresh plugin during network building.
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;

// Reads the op name prepended to a serialized plugin blob; *incremental is
// set to the number of header bytes consumed (length prefix + name chars).
std::string ExtractOpName(const void* serial_data, size_t serial_length,
                          size_t* incremental);

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>

// Lightweight serialization helpers used by TensorRT plugins to write and
// read their configuration into the byte buffer TensorRT provides.
//
// serialize_value advances *buffer past the bytes it writes;
// deserialize_value advances *buffer and shrinks *buffer_size accordingly.

template <typename T>
inline void serialize_value(void** buffer, T const& value);

template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value);

// Implementation detail. A named namespace is used (instead of an
// anonymous one) because an anonymous namespace in a header gives each
// translation unit its own distinct Serializer types, which makes the
// inline wrappers below an ODR hazard.
namespace details {

// Primary template: only the specializations below are usable.
template <typename T, class Enable = void>
struct Serializer {};

// Trivially-copyable scalars (arithmetic, enum, POD): raw memcpy.
template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                             std::is_enum<T>::value ||
                                             std::is_pod<T>::value>::type> {
  static size_t serialized_size(T const& value) { return sizeof(T); }
  static void serialize(void** buffer, T const& value) {
    ::memcpy(*buffer, &value, sizeof(T));
    reinterpret_cast<char*&>(*buffer) += sizeof(T);
  }
  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
    assert(*buffer_size >= sizeof(T));
    ::memcpy(value, *buffer, sizeof(T));
    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
    *buffer_size -= sizeof(T);
  }
};

// NUL-terminated C strings; the terminator is part of the wire format.
// NOTE: deserialize aliases into the buffer instead of copying — the
// returned pointer is only valid while the underlying buffer lives.
template <>
struct Serializer<const char*> {
  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
  static void serialize(void** buffer, const char* value) {
    ::strcpy(static_cast<char*>(*buffer), value);
    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
  }
  static void deserialize(void const** buffer, size_t* buffer_size,
                          const char** value) {
    *value = static_cast<char const*>(*buffer);
    // strnlen bounds the scan so a missing terminator cannot run off the
    // end of the buffer.
    size_t data_size = strnlen(*value, *buffer_size) + 1;
    assert(*buffer_size >= data_size);
    reinterpret_cast<char const*&>(*buffer) += data_size;
    *buffer_size -= data_size;
  }
};

// Vectors of trivially-copyable elements: element count (size_t) followed
// by the raw element bytes.
template <typename T>
struct Serializer<std::vector<T>,
                  typename std::enable_if<std::is_arithmetic<T>::value ||
                                          std::is_enum<T>::value ||
                                          std::is_pod<T>::value>::type> {
  static size_t serialized_size(std::vector<T> const& value) {
    return sizeof(value.size()) + value.size() * sizeof(T);
  }
  static void serialize(void** buffer, std::vector<T> const& value) {
    serialize_value(buffer, value.size());
    size_t nbyte = value.size() * sizeof(T);
    ::memcpy(*buffer, value.data(), nbyte);
    reinterpret_cast<char*&>(*buffer) += nbyte;
  }
  static void deserialize(void const** buffer, size_t* buffer_size,
                          std::vector<T>* value) {
    size_t size;
    deserialize_value(buffer, buffer_size, &size);
    // NOTE(review): size comes straight from the buffer; a corrupted blob
    // could request a huge allocation before the assert below fires.
    value->resize(size);
    size_t nbyte = value->size() * sizeof(T);
    assert(*buffer_size >= nbyte);
    ::memcpy(value->data(), *buffer, nbyte);
    reinterpret_cast<char const*&>(*buffer) += nbyte;
    *buffer_size -= nbyte;
  }
};

}  // namespace details

// Number of bytes serialize_value(buffer, value) will write.
template <typename T>
inline size_t serialized_size(T const& value) {
  return details::Serializer<T>::serialized_size(value);
}

// Writes value at *buffer and advances *buffer past it.
template <typename T>
inline void serialize_value(void** buffer, T const& value) {
  return details::Serializer<T>::serialize(buffer, value);
}

// Reads value from *buffer, advancing *buffer and shrinking *buffer_size.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value) {
  return details::Serializer<T>::deserialize(buffer, buffer_size, value);
}

0 commit comments

Comments
 (0)