1
+ #include " cumsum_plugin.h"
2
+
3
+ using namespace nvinfer1 ;
4
+
5
+ namespace trtorch {
6
+ namespace core {
7
+ namespace conversion {
8
+ namespace converters {
9
+ namespace impl {
10
+ namespace plugins {
11
+
12
+ /*
13
+ * CumsumPlugin class implementations
14
+ */
15
+
16
// Construct a cumsum plugin that accumulates along dimension `dim`.
CumsumPlugin::CumsumPlugin(int dim) : dim_(dim) {}
17
+
18
+ CumsumPlugin::CumsumPlugin (const char * data, size_t length) {
19
+ std::istringstream data_stream (std::string (data, length));
20
+
21
+ torch::serialize::InputArchive input_archive;
22
+ input_archive.load_from (data_stream);
23
+
24
+ {
25
+ torch::IValue value;
26
+ input_archive.read (" dim" , value);
27
+
28
+ dim_ = value.toInt ();
29
+ }
30
+ }
31
+
32
// Cumsum produces exactly one output tensor.
int CumsumPlugin::getNbOutputs() const {
  return 1;
}
35
+
36
// Plugin type name TensorRT uses to match this plugin with its creator.
const char* CumsumPlugin::getPluginType() const {
  return "Cumsum";
}
39
+
40
// Plugin version string (must agree with CumsumPluginCreator::getPluginVersion).
const char* CumsumPlugin::getPluginVersion() const {
  return "1";
}
43
+
44
// This plugin lives in the default (empty) plugin namespace.
const char* CumsumPlugin::getPluginNamespace() const {
  return "";
}
47
+
48
// Create an independent copy. `dim_` is the only configuration state, so a
// fresh instance constructed from it is a complete clone.
nvinfer1::IPluginV2DynamicExt* CumsumPlugin::clone() const {
  return new CumsumPlugin(dim_);
}
51
+
52
// Cumulative sum preserves shape: the output has exactly the (possibly
// dynamic) dimensions of the single input.
nvinfer1::DimsExprs CumsumPlugin::getOutputDimensions(
    int outputIndex,
    const nvinfer1::DimsExprs* inputs,
    int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  return inputs[0];
}
59
+
60
+ nvinfer1::DataType CumsumPlugin::getOutputDataType (int index, const nvinfer1::DataType* inputTypes, int nbInputs)
61
+ const {
62
+ return inputTypes[index];
63
+ }
64
+
65
int CumsumPlugin::initialize() {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  // TensorRT < 7.1: enqueue() runs the ATen cumsum kernel directly on the
  // device buffers, so tensors must be created as CUDA tensors.
  tensor_options_ = tensor_options_.device(c10::kCUDA);
#else
  // TensorRT >= 7.1: enqueue() falls back to a host round-trip (see the
  // WAR note there), so tensors are created on the CPU.
  tensor_options_ = tensor_options_.device(c10::kCPU);
#endif
  return 0;
}
73
+
74
+ void CumsumPlugin::serialize (void * buffer) const {
75
+ std::string data = serializeToString ();
76
+ size_t size = getSerializationSize ();
77
+
78
+ data.copy ((char *)buffer, size);
79
+ }
80
+
81
+ std::string CumsumPlugin::serializeToString () const {
82
+ torch::serialize::OutputArchive output_archive;
83
+
84
+ output_archive.write (" dim" , torch::IValue (dim_));
85
+
86
+ std::ostringstream data_str;
87
+ output_archive.save_to (data_str);
88
+
89
+ return data_str.str ();
90
+ }
91
+
92
// Size in bytes of the serialized state, computed by actually serializing
// the plugin and measuring the result.
size_t CumsumPlugin::getSerializationSize() const {
  return serializeToString().size();
}
95
+
96
+ bool CumsumPlugin::supportsFormatCombination (
97
+ int pos,
98
+ const nvinfer1::PluginTensorDesc* inOut,
99
+ int nbInputs,
100
+ int nbOutputs) {
101
+ TRTORCH_ASSERT (0 <= pos && pos <= 1 , " There should be exactly 2 connections to the plugin - 1 input, 1 output" );
102
+ TRTORCH_ASSERT (nbInputs == 1 , " Expected a single tensor as input to cumsum plugin" );
103
+ TRTORCH_ASSERT (nbOutputs == 1 , " Expected a single tensor as output to cumsum plugin" );
104
+
105
+ const PluginTensorDesc& in = inOut[0 ];
106
+
107
+ if (pos == 0 ) {
108
+ return (in.type == nvinfer1::DataType::kFLOAT || in.type == nvinfer1::DataType::kHALF ||
109
+ in.type == nvinfer1::DataType::kINT32 ) &&
110
+ (in.format == nvinfer1::TensorFormat::kLINEAR );
111
+ }
112
+
113
+ // pos == 1, accessing information about output tensor
114
+ const PluginTensorDesc& out = inOut[1 ];
115
+
116
+ return (in.type == out.type ) && (in.format == out.format );
117
+ }
118
+
119
// No per-configuration state is needed beyond `dim_`; configuration is a
// deliberate no-op.
void CumsumPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nbOutputs) {}
124
+
125
// The plugin does its work through ATen-managed tensors/host buffers and
// requests no TensorRT workspace memory.
size_t CumsumPlugin::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs,
    int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const {
  return 0;
}
132
+
133
// Execute cumsum over the TensorRT-provided buffers. Two strategies are
// compiled, selected by TensorRT version (mirroring initialize()):
//   - TRT < 7.1: run ATen's CUDA cumsum directly on the device buffers,
//     fencing the torch stream against TensorRT's stream with CUDA events.
//   - TRT >= 7.1: copy to host, run cumsum on CPU, copy back (WAR below).
// Returns 0 on success (no explicit failure paths here).
int CumsumPlugin::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) {
  // Refresh the dtype each launch to whatever type TensorRT selected for
  // this execution (float/half/int32 per supportsFormatCombination).
  tensor_options_ = tensor_options_.dtype(util::toATenDType(inputDesc[0].type));
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  // Wrap TensorRT's buffers as non-owning ATen tensors (no-op deleters —
  // TensorRT retains ownership of the memory).
  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
  // NOTE(review): the output is wrapped as a 1-D tensor of length
  // volume(dims) while the input keeps its full shape; this relies on
  // cumsum_out reshaping/resizing the flat view in place — confirm intended
  // (toVec(outputDesc->dims) would be the symmetric choice).
  at::Tensor output = at::from_blob(
      outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);

  // Run the torch kernel on its own stream, ordered against TensorRT's
  // stream with events: torch waits for TRT to produce the input, and TRT
  // waits for torch to produce the output. No host synchronization needed.
  at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
  at::cuda::CUDAStreamGuard torch_guard(torch_stream);

  cudaEvent_t event;
  cudaEventCreate(&event);
  cudaEventRecord(event, stream);

  cudaStreamWaitEvent(torch_stream.stream(), event, 0);

  at::cumsum_out(output, input, dim_);

  cudaEvent_t torch_event;
  cudaEventCreate(&torch_event);
  cudaEventRecord(torch_event, torch_stream.stream());

  cudaStreamWaitEvent(stream, torch_event, 0);

  cudaEventDestroy(event);
  cudaEventDestroy(torch_event);

  return 0;
#else
  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
  // Tensor in the context of TensorRT execution
  // Host round-trip: device -> host copy, CPU cumsum, host -> device copy.
  // NOTE(review): copy sizes use sizeof(float) even though half/int32 inputs
  // are accepted by supportsFormatCombination — confirm non-float types
  // cannot reach this path, otherwise the byte counts are wrong.
  float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
  cudaMemcpyAsync(
      input_blob,
      static_cast<const void*>(inputs[0]),
      util::volume(inputDesc->dims) * sizeof(float),
      cudaMemcpyDeviceToHost,
      stream);
  cudaStreamSynchronize(stream);

  // CPU tensor view over the host copy (tensor_options_ is kCPU here, set
  // by initialize() for this TensorRT version).
  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
  at::Tensor output;
  output = at::cumsum(input, dim_);

  cudaMemcpyAsync(
      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  free(input_blob);

  return 0;
#endif
}
193
+
194
+ /*
195
+ * CumsumPluginCreator class implementations
196
+ */
197
// Creator lives in the default (empty) plugin namespace, matching the plugin.
const char* CumsumPluginCreator::getPluginNamespace() const {
  return "";
}
200
+
201
// Must match CumsumPlugin::getPluginType() for deserialization lookup.
const char* CumsumPluginCreator::getPluginName() const {
  return "Cumsum";
}
204
+
205
// Must match CumsumPlugin::getPluginVersion() for deserialization lookup.
const char* CumsumPluginCreator::getPluginVersion() const {
  return "1";
}
208
+
209
// Creation from a PluginFieldCollection is not supported; plugins are built
// programmatically through the createPlugin(name, dim) overload instead.
nvinfer1::IPluginV2* CumsumPluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
  return nullptr;
}
212
+
213
// Programmatic creation path used by the converter: records the instance
// name and returns a plugin configured to accumulate along `dim`.
// Caller takes ownership of the returned plugin (TensorRT convention).
CumsumPlugin* CumsumPluginCreator::createPlugin(const char* name, int dim) {
  name_ = name;
  return new CumsumPlugin(dim);
}
217
+
218
+ nvinfer1::IPluginV2* CumsumPluginCreator::deserializePlugin (
219
+ const char * name,
220
+ const void * serialData,
221
+ size_t serialLength) {
222
+ name_ = name;
223
+ return new CumsumPlugin ((const char *)serialData, serialLength);
224
+ }
225
+
226
// No plugin fields are exposed (field-collection creation is unsupported).
const nvinfer1::PluginFieldCollection* CumsumPluginCreator::getFieldNames() {
  return nullptr;
}
229
+
230
+ REGISTER_TENSORRT_PLUGIN (CumsumPluginCreator);
231
+
232
+ } // namespace plugins
233
+ } // namespace impl
234
+ } // namespace converters
235
+ } // namespace conversion
236
+ } // namespace core
237
+ } // namespace trtorch
0 commit comments