@@ -9,57 +9,57 @@ namespace converters {
9
9
namespace impl {
10
10
namespace plugins {
11
11
12
- /*
12
+ /*
13
13
* InterpolatePlugin class implementations
14
14
*/
15
15
16
- InterpolatePlugin::InterpolatePlugin (std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) :
17
- in_shape (in_shape), out_shape (out_shape), size (size), mode (mode), align_corners (align_corners)
16
+ InterpolatePlugin::InterpolatePlugin (std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) :
17
+ in_shape_ (in_shape), out_shape_ (out_shape), size_ (size), mode_ (mode), align_corners_ (align_corners)
18
18
{}
19
19
20
20
/// Reconstructs a plugin from the serialized byte blob produced by
/// serializeToString(): a torch::serialize archive holding the five
/// configuration fields under fixed keys.
///
/// @param data   pointer to the serialized bytes
/// @param length number of bytes at `data`
InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
  std::istringstream data_stream(std::string(data, length));

  torch::serialize::InputArchive input_archive;
  input_archive.load_from(data_stream);

  // Reads the archive entry stored under `key` and returns it as an IValue.
  // Replaces five identical copy-pasted { IValue; read; convert } stanzas.
  auto read_field = [&input_archive](const char* key) {
    torch::IValue value;
    input_archive.read(key, value);
    return value;
  };

  in_shape_ = read_field("in_shape").toIntVector();
  out_shape_ = read_field("out_shape").toIntVector();
  size_ = read_field("size").toIntVector();
  mode_ = read_field("mode").toStringRef();
  align_corners_ = read_field("align_corners").toBool();
}
52
52
53
53
std::vector<int64_t > InterpolatePlugin::getInputShape () {
54
- return in_shape ;
54
+ return in_shape_ ;
55
55
}
56
56
57
57
std::vector<int64_t > InterpolatePlugin::getOutputShape () {
58
- return out_shape ;
58
+ return out_shape_ ;
59
59
}
60
60
61
61
std::vector<int64_t > InterpolatePlugin::getOutputSize () {
62
- return size ;
62
+ return size_ ;
63
63
}
64
64
65
65
int InterpolatePlugin::getNbOutputs () const {
@@ -80,14 +80,14 @@ const char* InterpolatePlugin::getPluginNamespace() const {
80
80
81
81
82
82
/// Produces an independently configured copy of this plugin.
/// TensorRT takes ownership of the returned instance (raw `new` is the
/// required form for the IPluginV2DynamicExt::clone contract).
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
  auto* duplicate = new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
  return duplicate;
}
85
85
86
86
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) {
87
87
nvinfer1::DimsExprs output (inputs[0 ]);
88
88
89
- for (unsigned int i = 0 ; i < out_shape .size (); i++) {
90
- output.d [i] = exprBuilder.constant (out_shape [i]);
89
+ for (unsigned int i = 0 ; i < out_shape_ .size (); i++) {
90
+ output.d [i] = exprBuilder.constant (out_shape_ [i]);
91
91
}
92
92
93
93
return output;
@@ -98,10 +98,10 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
98
98
}
99
99
100
100
/// One-time setup: pins the staging tensor options to CPU / FLOAT32
/// (c10::kFloat == FLOAT32). CPU rather than CUDA because execution stages
/// data through host memory. Returns 0 on success.
int InterpolatePlugin::initialize() {
  tensor_options_ = tensor_options_.device(c10::kCPU).dtype(c10::kFloat);
  return 0;
}
@@ -117,11 +117,11 @@ void InterpolatePlugin::serialize(void* buffer) const {
117
117
std::string InterpolatePlugin::serializeToString () const {
118
118
torch::serialize::OutputArchive output_archive;
119
119
120
- output_archive.write (" in_shape" , torch::IValue (in_shape ));
121
- output_archive.write (" out_shape" , torch::IValue (out_shape ));
122
- output_archive.write (" size" , torch::IValue (size ));
123
- output_archive.write (" mode" , torch::IValue (mode ));
124
- output_archive.write (" align_corners" , torch::IValue (align_corners ));
120
+ output_archive.write (" in_shape" , torch::IValue (in_shape_ ));
121
+ output_archive.write (" out_shape" , torch::IValue (out_shape_ ));
122
+ output_archive.write (" size" , torch::IValue (size_ ));
123
+ output_archive.write (" mode" , torch::IValue (mode_ ));
124
+ output_archive.write (" align_corners" , torch::IValue (align_corners_ ));
125
125
126
126
std::ostringstream data_str;
127
127
output_archive.save_to (data_str);
@@ -146,56 +146,48 @@ bool InterpolatePlugin::supportsFormatCombination(int pos, const nvinfer1::Plugi
146
146
147
147
// pos == 1, accessing information about output tensor
148
148
const PluginTensorDesc& out = inOut[1 ];
149
-
149
+
150
150
return (in.type == out.type ) && (in.format == out.format );
151
151
}
152
152
153
153
/// Configuration callback from TensorRT with the concrete tensor
/// descriptions. Only FP32 is supported, so the data type is fixed here;
/// the descriptor arguments are intentionally unused.
void InterpolatePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {
  dtype_ = DataType::kFLOAT;
}
156
156
157
157
/// This plugin performs its computation through ATen tensors and needs no
/// TensorRT-managed scratch workspace.
size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
  return 0;
}
160
160
161
- int InterpolatePlugin::enqueue (const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void * const * inputs,
162
- void * const * outputs, void * workspace,
161
+ int InterpolatePlugin::enqueue (const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void * const * inputs,
162
+ void * const * outputs, void * workspace,
163
163
cudaStream_t stream) {
164
- at::Tensor input = at::from_blob ((void *) inputs[0 ], util::toVec (inputDesc->dims ), [](void *){}, tensor_options);
165
- at::Tensor output = at::from_blob (outputs[0 ], out_shape, [](void *){}, tensor_options);
166
-
167
- at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool ();
168
- at::cuda::CUDAStreamGuard torch_guard (torch_stream);
169
-
170
- cudaEvent_t event;
171
- cudaEventCreate (&event);
172
- cudaEventRecord (event, stream);
173
-
174
- cudaStreamWaitEvent (torch_stream.stream (), event, 0 );
175
-
176
- if (mode == " linear" ) {
177
- at::upsample_linear1d_out (output, input, {size[0 ]}, align_corners);
178
- } else if (mode == " bilinear" ) {
179
- at::upsample_bilinear2d_out (output, input, {size[0 ], size[1 ]}, align_corners);
180
- } else if (mode == " trilinear" ) {
181
- at::upsample_trilinear3d_out (output, input, {size[0 ], size[1 ], size[2 ]}, align_corners);
182
- } else if (mode == " adaptive_pool2d" ) {
183
- at::adaptive_avg_pool2d_out (output, input, {size[0 ], size[1 ]});
164
+ // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen kernels
165
+ // HACK: WAR because there is a segfault if you try to create a CUDA Tensor in the context of TensorRT execution
166
+ float * input_blob = (float *) malloc (util::volume (inputDesc->dims ) * sizeof (float ));
167
+ cudaMemcpyAsync (input_blob, static_cast <const void *>(inputs[0 ]), util::volume (inputDesc->dims ) * sizeof (float ), cudaMemcpyDeviceToHost, stream);
168
+ cudaStreamSynchronize (stream);
169
+
170
+ at::Tensor input = at::from_blob ((void *)input_blob, util::toVec (inputDesc->dims ), tensor_options_);
171
+
172
+ at::Tensor output;
173
+ if (mode_ == " adaptive_pool2d" ) {
174
+ output = at::adaptive_avg_pool2d (input, {size_[0 ], size_[1 ]});
184
175
}
185
176
186
- cudaEvent_t torch_event;
187
- cudaEventCreate (&torch_event);
188
- cudaEventRecord (torch_event, torch_stream.stream ());
177
+ output = output.contiguous ();
178
+ for (int i = 0 ; i < util::volume (outputDesc->dims ); i++) {
179
+ std::cout << ((float *)output.data_ptr ())[i] << std::endl;
180
+ }
189
181
190
- cudaStreamWaitEvent (stream, torch_event, 0 );
182
+ cudaMemcpyAsync (outputs[0 ], output.data_ptr (), util::volume (outputDesc->dims ) * sizeof (float ), cudaMemcpyHostToDevice, stream);
183
+ cudaStreamSynchronize (stream);
191
184
192
- cudaEventDestroy (event);
193
- cudaEventDestroy (torch_event);
185
+ free (input_blob);
194
186
195
187
return 0 ;
196
188
}
197
189
198
- /*
190
+ /*
199
191
* InterpolatePluginCreator class implementations
200
192
*/
201
193
const char * InterpolatePluginCreator::getPluginNamespace () const {
@@ -214,15 +206,15 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
214
206
return nullptr ;
215
207
}
216
208
217
- InterpolatePlugin* InterpolatePluginCreator::createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape,
218
- std::vector<int64_t > size,
209
+ InterpolatePlugin* InterpolatePluginCreator::createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape,
210
+ std::vector<int64_t > size,
219
211
std::string mode, bool align_corners) {
220
- name = name;
212
+ name_ = name;
221
213
return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
222
214
}
223
215
224
216
/// Rebuilds a plugin from a serialized blob (see the deserializing
/// constructor); records the requested name first. Caller owns the result.
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) {
  name_ = name;
  auto* plugin = new InterpolatePlugin(static_cast<const char*>(serialData), serialLength);
  return plugin;
}
228
220
0 commit comments