Skip to content

Commit 0b96268

Browse files
committed
fix comments
test=develop
1 parent e5bf861 commit 0b96268

File tree

8 files changed

+64
-60
lines changed

8 files changed

+64
-60
lines changed

paddle/fluid/inference/tensorrt/convert/split_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
3535
int input_num = op_desc.Input("X").size();
3636
size_t output_num = op_desc.Output("Out").size();
3737

38+
// Get Attrs
3839
PADDLE_ENFORCE(input_num == 1);
3940
int axis = boost::get<int>(op_desc.GetAttr("axis"));
4041
std::vector<int> output_lengths =
@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
4849

4950
PADDLE_ENFORCE(output_lengths.size() == output_num);
5051

52+
//
5153
SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
5254
nvinfer1::IPluginLayer* layer =
53-
engine_->addPlugin(&input, input_num, plugin);
55+
engine_->AddPlugin(&input, input_num, plugin);
5456

5557
std::string layer_name = "split (Output: ";
5658
for (size_t i = 0; i < output_num; i++) {

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
254254
cudaSetDevice(device_);
255255
}
256256

257-
nvinfer1::IPluginLayer *TensorRTEngine::addPlugin(
257+
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
258258
nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
259259
owned_plugin_.emplace_back(plugin);
260260
return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
126126
void SetRuntimeBatch(size_t batch_size);
127127
int GetRuntimeBatch();
128128
int GetDevice() { return device_; }
129-
nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs,
129+
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
130130
int nbInputs, PluginTensorRT*);
131131

132132
// A pointer to CPU memory is needed of the TRT weight.

paddle/fluid/inference/tensorrt/plugin/serialize.h

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
#include <vector>
2121

2222
template <typename T>
23-
inline void serialize_value(void** buffer, T const& value);
23+
inline void SerializeValue(void** buffer, T const& value);
2424

2525
template <typename T>
26-
inline void deserialize_value(void const** buffer, size_t* buffer_size,
27-
T* value);
26+
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
27+
T* value);
2828

2929
namespace {
3030

@@ -35,27 +35,27 @@ template <typename T>
3535
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
3636
std::is_enum<T>::value ||
3737
std::is_pod<T>::value>::type> {
38-
static size_t serialized_size(T const& value) { return sizeof(T); }
39-
static void serialize(void** buffer, T const& value) {
40-
::memcpy(*buffer, &value, sizeof(T));
38+
static size_t SerializedSize(T const& value) { return sizeof(T); }
39+
static void Serialize(void** buffer, T const& value) {
40+
std::memcpy(*buffer, &value, sizeof(T));
4141
reinterpret_cast<char*&>(*buffer) += sizeof(T);
4242
}
43-
static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
43+
static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
4444
assert(*buffer_size >= sizeof(T));
45-
::memcpy(value, *buffer, sizeof(T));
45+
std::memcpy(value, *buffer, sizeof(T));
4646
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
4747
*buffer_size -= sizeof(T);
4848
}
4949
};
5050

5151
template <>
5252
struct Serializer<const char*> {
53-
static size_t serialized_size(const char* value) { return strlen(value) + 1; }
54-
static void serialize(void** buffer, const char* value) {
55-
::strcpy(static_cast<char*>(*buffer), value);
53+
static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
54+
static void Serialize(void** buffer, const char* value) {
55+
std::strcpy(static_cast<char*>(*buffer), value);
5656
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
5757
}
58-
static void deserialize(void const** buffer, size_t* buffer_size,
58+
static void Deserialize(void const** buffer, size_t* buffer_size,
5959
const char** value) {
6060
*value = static_cast<char const*>(*buffer);
6161
size_t data_size = strnlen(*value, *buffer_size) + 1;
@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
7070
typename std::enable_if<std::is_arithmetic<T>::value ||
7171
std::is_enum<T>::value ||
7272
std::is_pod<T>::value>::type> {
73-
static size_t serialized_size(std::vector<T> const& value) {
73+
static size_t SerializedSize(std::vector<T> const& value) {
7474
return sizeof(value.size()) + value.size() * sizeof(T);
7575
}
76-
static void serialize(void** buffer, std::vector<T> const& value) {
77-
serialize_value(buffer, value.size());
76+
static void Serialize(void** buffer, std::vector<T> const& value) {
77+
SerializeValue(buffer, value.size());
7878
size_t nbyte = value.size() * sizeof(T);
79-
::memcpy(*buffer, value.data(), nbyte);
79+
std::memcpy(*buffer, value.data(), nbyte);
8080
reinterpret_cast<char*&>(*buffer) += nbyte;
8181
}
82-
static void deserialize(void const** buffer, size_t* buffer_size,
82+
static void Deserialize(void const** buffer, size_t* buffer_size,
8383
std::vector<T>* value) {
8484
size_t size;
85-
deserialize_value(buffer, buffer_size, &size);
85+
DeserializeValue(buffer, buffer_size, &size);
8686
value->resize(size);
8787
size_t nbyte = value->size() * sizeof(T);
8888
assert(*buffer_size >= nbyte);
89-
::memcpy(value->data(), *buffer, nbyte);
89+
std::memcpy(value->data(), *buffer, nbyte);
9090
reinterpret_cast<char const*&>(*buffer) += nbyte;
9191
*buffer_size -= nbyte;
9292
}
@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
9595
} // namespace
9696

9797
template <typename T>
98-
inline size_t serialized_size(T const& value) {
99-
return Serializer<T>::serialized_size(value);
98+
inline size_t SerializedSize(T const& value) {
99+
return Serializer<T>::SerializedSize(value);
100100
}
101101

102102
template <typename T>
103-
inline void serialize_value(void** buffer, T const& value) {
104-
return Serializer<T>::serialize(buffer, value);
103+
inline void SerializeValue(void** buffer, T const& value) {
104+
return Serializer<T>::Serialize(buffer, value);
105105
}
106106

107107
template <typename T>
108-
inline void deserialize_value(void const** buffer, size_t* buffer_size,
109-
T* value) {
110-
return Serializer<T>::deserialize(buffer, buffer_size, value);
108+
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
109+
T* value) {
110+
return Serializer<T>::Deserialize(buffer, buffer_size, value);
111111
}

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
3737
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
3838
}
3939
segment_offsets_ = segment_offsets;
40-
d_segment_offsets_ = segment_offsets;
4140
nvinfer1::Dims dims = this->getInputDims(0);
4241
nx_ = 1;
4342
for (int i = dims.nbDims - 1; i > axis_; --i) {
@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
5554
void** outputs, void* workspace, cudaStream_t stream) {
5655
auto const& input_dims = this->getInputDims(0);
5756
int input_size = 0;
58-
int const* d_segment_offsets_ptr =
59-
thrust::raw_pointer_cast(&d_segment_offsets_[0]);
6057
float const* idata = reinterpret_cast<float const*>(inputs[0]);
6158
float** odatas = reinterpret_cast<float**>(outputs);
6259

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#pragma once
1616

17-
#include <thrust/device_vector.h>
1817
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
1918

2019
namespace paddle {
@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
2524
int axis_;
2625
std::vector<int> output_length_;
2726
int nx_, ny_, nz_;
28-
thrust::device_vector<int> d_segment_offsets_;
2927
std::vector<int> segment_offsets_;
3028

3129
protected:
3230
virtual size_t getSerializationSize() override {
33-
return serialized_size(axis_) + serialized_size(output_length_) +
31+
return SerializedSize(axis_) + SerializedSize(output_length_) +
3432
getBaseSerializationSize();
3533
}
3634

35+
// TRT will call this func when we need to serialize the configuration of
36+
// tensorrt.
37+
// It should not be called by users.
3738
virtual void serialize(void *buffer) override {
3839
serializeBase(buffer);
39-
serialize_value(&buffer, axis_);
40-
serialize_value(&buffer, output_length_);
40+
SerializeValue(&buffer, axis_);
41+
SerializeValue(&buffer, output_length_);
4142
}
4243

4344
public:
@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
4647
assert(axis <= nvinfer1::Dims::MAX_DIMS);
4748
}
4849

50+
// It was used for tensorrt deserialization.
51+
// It should not be called by users.
4952
SplitPlugin(void const *serialData, size_t serialLength) {
5053
deserializeBase(serialData, serialLength);
51-
deserialize_value(&serialData, &serialLength, &axis_);
52-
deserialize_value(&serialData, &serialLength, &output_length_);
54+
DeserializeValue(&serialData, &serialLength, &axis_);
55+
DeserializeValue(&serialData, &serialLength, &output_length_);
5356
}
5457

5558
SplitPlugin *clone() const override {
@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
6467
virtual int initialize() override;
6568
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
6669
void *workspace, cudaStream_t stream) override;
67-
68-
void setAxis(int axis) { axis_ = axis; }
69-
70-
void setOutputLengths(const std::vector<int> &output_lengths) {
71-
output_length_ = output_lengths;
72-
}
7370
};
7471

7572
} // tensorrt

paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,23 @@ namespace inference {
1919
namespace tensorrt {
2020

2121
void PluginTensorRT::serializeBase(void*& buffer) {
22-
serialize_value(&buffer, input_dims_);
23-
serialize_value(&buffer, max_batch_size_);
24-
serialize_value(&buffer, data_type_);
25-
serialize_value(&buffer, data_format_);
22+
SerializeValue(&buffer, input_dims_);
23+
SerializeValue(&buffer, max_batch_size_);
24+
SerializeValue(&buffer, data_type_);
25+
SerializeValue(&buffer, data_format_);
2626
}
2727

2828
void PluginTensorRT::deserializeBase(void const*& serialData,
2929
size_t& serialLength) {
30-
deserialize_value(&serialData, &serialLength, &input_dims_);
31-
deserialize_value(&serialData, &serialLength, &max_batch_size_);
32-
deserialize_value(&serialData, &serialLength, &data_type_);
33-
deserialize_value(&serialData, &serialLength, &data_format_);
30+
DeserializeValue(&serialData, &serialLength, &input_dims_);
31+
DeserializeValue(&serialData, &serialLength, &max_batch_size_);
32+
DeserializeValue(&serialData, &serialLength, &data_type_);
33+
DeserializeValue(&serialData, &serialLength, &data_format_);
3434
}
3535

3636
size_t PluginTensorRT::getBaseSerializationSize() {
37-
return (serialized_size(input_dims_) + serialized_size(max_batch_size_) +
38-
serialized_size(data_type_) + serialized_size(data_format_));
37+
return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
38+
SerializedSize(data_type_) + SerializedSize(data_format_));
3939
}
4040

4141
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,

paddle/fluid/inference/tensorrt/plugin/trt_plugin.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,32 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
4141
size_t getWorkspaceSize(int) const override { return 0; }
4242
void terminate() override {}
4343
virtual ~PluginTensorRT() {}
44-
45-
// The following functions need to be overridden in the subclass.
46-
virtual nvinfer1::IPluginExt* clone() const = 0;
47-
virtual const char* getPluginType() const = 0;
48-
int initialize() override { return 0; }
44+
// Check format support. The default is FLOAT32 and NCHW.
4945
bool supportsFormat(nvinfer1::DataType type,
5046
nvinfer1::PluginFormat format) const override;
5147
void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
5248
const nvinfer1::Dims* outputDims, int nbOutputs,
5349
nvinfer1::DataType type,
5450
nvinfer1::PluginFormat format,
5551
int maxBatchSize) override;
52+
53+
// *NOTE* The following functions need to be overridden in the subclass.
54+
virtual nvinfer1::IPluginExt* clone() const = 0;
55+
virtual const char* getPluginType() const = 0;
56+
// Initialize the layer for execution. This is called when the engine is
57+
// created.
58+
int initialize() override { return 0; }
59+
// Serialize the layer config to buffer.
5660
virtual void serialize(void* buffer) = 0;
5761
virtual size_t getSerializationSize() = 0;
62+
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
63+
void* workspace, cudaStream_t stream) = 0;
5864

5965
protected:
66+
// Deserialize input_dims, max_batch_size, data_type, data_format
6067
void deserializeBase(void const*& serialData, size_t& serialLength);
6168
size_t getBaseSerializationSize();
69+
// Serialize input_dims, max_batch_size, data_type, data_format
6270
void serializeBase(void*& buffer);
6371

6472
std::vector<nvinfer1::Dims> input_dims_;

0 commit comments

Comments
 (0)