Skip to content

Commit 5167c47

Browse files
committed
support cumsum converter
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 26d5c65 commit 5167c47

File tree

7 files changed

+489
-3
lines changed

7 files changed

+489
-3
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ cc_library(
3838
"impl/concat.cpp",
3939
"impl/constant.cpp",
4040
"impl/conv_deconv.cpp",
41+
"impl/cumsum.cpp",
4142
"impl/element_wise.cpp",
4243
"impl/expand.cpp",
4344
"impl/linear.cpp",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "core/util/trt_util.h"
6+
#include "plugins/cumsum_plugin.h"
7+
#include "torch/torch.h"
8+
9+
#include <ATen/ATen.h>
10+
#include <vector>
11+
12+
namespace trtorch {
13+
namespace core {
14+
namespace conversion {
15+
namespace converters {
16+
namespace impl {
17+
namespace {
18+
19+
// Builds an aten::cumsum plugin layer and attaches it to the TensorRT network.
// The plugin dispatches to ATen at runtime, hence the performance warning.
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name, int dim) {
  LOG_WARNING("Cumsum layer will be run through ATen, not TensorRT. Performance may be lower than expected");

  // The creator only needs to live for this call, so keep it on the stack
  // instead of leaking a heap allocation; only the created plugin must
  // outlive this function (the network keeps a reference once it is added).
  plugins::CumsumPluginCreator creator;
  auto plugin = creator.createPlugin(name, dim);

  // &in is ITensor**, which converts to ITensor* const* without a cast
  auto cumsum_layer = ctx->net->addPluginV2(&in, 1, *plugin);
  TRTORCH_CHECK(cumsum_layer, "Unable to create cumsum plugin from node: " << *n);

  cumsum_layer->setName(util::node_info(n).c_str());

  auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], cumsum_layer->getOutput(0));

  LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
}
34+
35+
// Converter for aten::cumsum: validates the requested dimension against the
// input rank, normalizes negative dims, then lowers the node to a Cumsum
// plugin layer (executed through ATen at runtime).
auto cumsum_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto self = args[0].ITensor();
       auto self_dims = self->getDimensions();
       int dim = args[1].unwrapToInt();
       // A valid dim lies in [-nbDims, nbDims - 1]
       TRTORCH_CHECK(
           (dim >= 0 && dim < self_dims.nbDims) || (dim < 0 && (self_dims.nbDims + dim >= 0)),
           "Dimension out of range (expected to be in range of [" << -self_dims.nbDims << ", " << self_dims.nbDims - 1
                                                                  << "], but got " << dim << ")");
       // Map a negative dim onto its positive equivalent
       if (dim < 0) {
         dim += self_dims.nbDims;
       }
       create_plugin(ctx, n, self, "Cumsum", dim);
       return true;
     }});
51+
52+
} // namespace
53+
} // namespace impl
54+
} // namespace converters
55+
} // namespace conversion
56+
} // namespace core
57+
} // namespace trtorch

core/conversion/converters/impl/plugins/BUILD

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ config_setting(
1010
cc_library(
1111
name = "plugins",
1212
hdrs = [
13-
"interpolate_plugin.h"
13+
"interpolate_plugin.h",
14+
"cumsum_plugin.h"
1415
],
1516
srcs = [
16-
"interpolate_plugin.cpp"
17+
"interpolate_plugin.cpp",
18+
"cumsum_plugin.cpp"
1719
],
1820
deps = [
1921
"@tensorrt//:nvinfer",
@@ -37,5 +39,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
3739
pkg_tar(
3840
name = "include",
3941
package_dir = "core/conversion/converters/impl/plugins",
40-
srcs = ["interpolate_plugin.h"],
42+
srcs = ["interpolate_plugin.h", "cumsum_plugin.h"],
4143
)
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
#include "cumsum_plugin.h"
2+
3+
using namespace nvinfer1;
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace conversion {
8+
namespace converters {
9+
namespace impl {
10+
namespace plugins {
11+
12+
/*
13+
* CumsumPlugin class implementations
14+
*/
15+
16+
// Construct a plugin that accumulates along dimension `dim`
CumsumPlugin::CumsumPlugin(int dim) : dim_(dim) {}
17+
18+
// Reconstruct a plugin from the serialized blob produced by serialize():
// a torch serialize archive containing the single "dim" field.
CumsumPlugin::CumsumPlugin(const char* data, size_t length) {
  std::istringstream serialized(std::string(data, length));

  torch::serialize::InputArchive archive;
  archive.load_from(serialized);

  torch::IValue dim_value;
  archive.read("dim", dim_value);
  dim_ = dim_value.toInt();
}
31+
32+
// The plugin produces exactly one output tensor
int CumsumPlugin::getNbOutputs() const {
  return 1;
}

// Type name used for plugin-registry lookup; must match the creator's name
const char* CumsumPlugin::getPluginType() const {
  return "Cumsum";
}

// Plugin version; must match the creator's version
const char* CumsumPlugin::getPluginVersion() const {
  return "1";
}

// Registered in the default (empty) plugin namespace
const char* CumsumPlugin::getPluginNamespace() const {
  return "";
}

// Deep copy: a fresh plugin with the same accumulation dimension
nvinfer1::IPluginV2DynamicExt* CumsumPlugin::clone() const {
  return new CumsumPlugin(dim_);
}
51+
52+
// Cumulative sum is elementwise along one axis, so the output shape is
// identical to the input shape
nvinfer1::DimsExprs CumsumPlugin::getOutputDimensions(
    int outputIndex,
    const nvinfer1::DimsExprs* inputs,
    int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  return inputs[0];
}
59+
60+
// The output dtype mirrors the (single) input tensor's dtype.
// Index inputTypes by the input slot (0), not by the output index: with one
// input and one output the two coincide, but indexing by `index` would read
// out of bounds if the plugin ever had more outputs than inputs.
nvinfer1::DataType CumsumPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)
    const {
  return inputTypes[0];
}
64+
65+
// Choose where the ATen fallback tensors live. On TensorRT < 7.1 ATen runs
// on CUDA; on newer versions tensors are staged on the CPU (see the WAR
// comment in enqueue about creating CUDA tensors inside TensorRT execution).
int CumsumPlugin::initialize() {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  tensor_options_ = tensor_options_.device(c10::kCUDA);
#else
  tensor_options_ = tensor_options_.device(c10::kCPU);
#endif
  return 0;
}
73+
74+
// Copy the serialized plugin state into `buffer`, which TensorRT sizes via
// getSerializationSize()
void CumsumPlugin::serialize(void* buffer) const {
  std::string data = serializeToString();
  size_t size = getSerializationSize();

  // static_cast instead of a C-style cast for a checked pointer conversion
  data.copy(static_cast<char*>(buffer), size);
}
80+
81+
// Serialize the plugin state (the "dim" field) into a string using a torch
// output archive; the CumsumPlugin(const char*, size_t) ctor is the inverse.
std::string CumsumPlugin::serializeToString() const {
  std::ostringstream serialized;

  torch::serialize::OutputArchive archive;
  archive.write("dim", torch::IValue(dim_));
  archive.save_to(serialized);

  return serialized.str();
}
91+
92+
// Size of the serialized blob. NOTE(review): this serializes the state just
// to measure it, and serialize() repeats the work — presumably acceptable
// since it only runs at engine-save time.
size_t CumsumPlugin::getSerializationSize() const {
  return serializeToString().size();
}
95+
96+
bool CumsumPlugin::supportsFormatCombination(
97+
int pos,
98+
const nvinfer1::PluginTensorDesc* inOut,
99+
int nbInputs,
100+
int nbOutputs) {
101+
TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
102+
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to cumsum plugin");
103+
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to cumsum plugin");
104+
105+
const PluginTensorDesc& in = inOut[0];
106+
107+
if (pos == 0) {
108+
return (in.type == nvinfer1::DataType::kFLOAT || in.type == nvinfer1::DataType::kHALF ||
109+
in.type == nvinfer1::DataType::kINT32) &&
110+
(in.format == nvinfer1::TensorFormat::kLINEAR);
111+
}
112+
113+
// pos == 1, accessing information about output tensor
114+
const PluginTensorDesc& out = inOut[1];
115+
116+
return (in.type == out.type) && (in.format == out.format);
117+
}
118+
119+
// Nothing to configure: the accumulation dim is fixed at construction time
// and no shape/format-dependent state is kept
void CumsumPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nbOutputs) {}
124+
125+
// No TensorRT-managed scratch space required; ATen manages its own memory
size_t CumsumPlugin::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs,
    int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const {
  return 0;
}
132+
133+
// Execute the cumulative sum by calling into ATen.
// TRT < 7.1: wrap the device buffers directly as CUDA tensors and fence the
// torch stream against TensorRT's stream with events.
// TRT >= 7.1: stage through host memory and run ATen on the CPU (see the
// WAR comment below).
int CumsumPlugin::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) {
  // Match the ATen dtype to the TensorRT-negotiated input dtype
  tensor_options_ = tensor_options_.dtype(util::toATenDType(inputDesc[0].type));
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  // Zero-copy wrap of TensorRT's buffers; no-op deleters since TensorRT
  // owns the memory.
  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
  // NOTE(review): the output is wrapped flat (by volume) rather than with
  // the full dims like the input — confirm at::cumsum_out tolerates this
  // shape mismatch for an in-place out tensor
  at::Tensor output = at::from_blob(
      outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);

  // Run ATen on its own stream, ordered against TensorRT's stream by events
  at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
  at::cuda::CUDAStreamGuard torch_guard(torch_stream);

  // torch stream waits until TensorRT has produced the input buffer
  cudaEvent_t event;
  cudaEventCreate(&event);
  cudaEventRecord(event, stream);

  cudaStreamWaitEvent(torch_stream.stream(), event, 0);

  at::cumsum_out(output, input, dim_);

  // TensorRT's stream waits until the ATen kernel has finished writing
  cudaEvent_t torch_event;
  cudaEventCreate(&torch_event);
  cudaEventRecord(torch_event, torch_stream.stream());

  cudaStreamWaitEvent(stream, torch_event, 0);

  cudaEventDestroy(event);
  cudaEventDestroy(torch_event);

  return 0;
#else
  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
  // Tensor in the context of TensorRT execution
  // NOTE(review): both memcpys size elements as sizeof(float), yet
  // supportsFormatCombination also admits kHALF/kINT32 — verify the element
  // size for non-float inputs
  float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
  cudaMemcpyAsync(
      input_blob,
      static_cast<const void*>(inputs[0]),
      util::volume(inputDesc->dims) * sizeof(float),
      cudaMemcpyDeviceToHost,
      stream);
  cudaStreamSynchronize(stream);

  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
  at::Tensor output;
  output = at::cumsum(input, dim_);

  cudaMemcpyAsync(
      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  free(input_blob);

  return 0;
#endif
}
193+
194+
/*
 * CumsumPluginCreator class implementations
 */
// Creator registers in the default (empty) plugin namespace
const char* CumsumPluginCreator::getPluginNamespace() const {
  return "";
}

// Must match CumsumPlugin::getPluginType()
const char* CumsumPluginCreator::getPluginName() const {
  return "Cumsum";
}

// Must match CumsumPlugin::getPluginVersion()
const char* CumsumPluginCreator::getPluginVersion() const {
  return "1";
}

// Field-collection based creation is not supported; plugins are built
// programmatically via createPlugin(name, dim) instead
nvinfer1::IPluginV2* CumsumPluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
  return nullptr;
}
212+
213+
// Create a plugin accumulating along `dim`; the caller takes ownership of
// the returned plugin
CumsumPlugin* CumsumPluginCreator::createPlugin(const char* name, int dim) {
  name_ = name;
  return new CumsumPlugin(dim);
}
217+
218+
// Recreate a plugin from a serialized engine blob (inverse of
// CumsumPlugin::serialize())
nvinfer1::IPluginV2* CumsumPluginCreator::deserializePlugin(
    const char* name,
    const void* serialData,
    size_t serialLength) {
  name_ = name;
  return new CumsumPlugin((const char*)serialData, serialLength);
}

// No plugin fields are exposed (see the field-collection createPlugin above)
const nvinfer1::PluginFieldCollection* CumsumPluginCreator::getFieldNames() {
  return nullptr;
}
229+
230+
REGISTER_TENSORRT_PLUGIN(CumsumPluginCreator);
231+
232+
} // namespace plugins
233+
} // namespace impl
234+
} // namespace converters
235+
} // namespace conversion
236+
} // namespace core
237+
} // namespace trtorch

0 commit comments

Comments
 (0)