Skip to content

Commit 2066f53

Browse files
authored
Adding dynamic split support (#326)
1 parent fe0591c commit 2066f53

File tree

3 files changed

+78
-33
lines changed

3 files changed

+78
-33
lines changed

builtin_op_importers.cpp

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,48 +1839,59 @@ DEFINE_BUILTIN_OP_IMPORTER(SpatialBN)
18391839

18401840
DEFINE_BUILTIN_OP_IMPORTER(Split)
{
    // Pre-opset-13 Split: the split sizes come from the "split" attribute, so exactly
    // one tensor input is expected. (Restores the guard the dynamic-shape refactor dropped.)
    ASSERT(inputs.size() == 1, ErrorCode::kUNSUPPORTED_NODE);
    const int numOutputs = node.output().size();

    nvinfer1::ITensor* tensorPtr = &convertToTensor(inputs.at(0), ctx);
    const int rank = tensorPtr->getDimensions().nbDims;
    // Runtime shape of the input; used to build dynamic slice starts/sizes.
    nvinfer1::ITensor* shape = ctx->network()->addShape(*tensorPtr)->getOutput(0);

    OnnxAttrs attrs(node);
    int axis = attrs.get<int>("axis", 0);
    TRT_CHECK(convert_axis(axis, rank));

    // Explicit per-output lengths, if the node provides them.
    std::vector<int> outputLengths;
    if (attrs.count("split"))
    {
        outputLengths = attrs.get<std::vector<int>>("split");
        ASSERT(static_cast<int>(outputLengths.size()) == numOutputs, ErrorCode::kINVALID_NODE);
    }

    // Running start offset along the split axis, as a shape-tensor scalar.
    nvinfer1::ITensor* startSliceAxis{addConstantScalar<int32_t>(ctx, 0, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0)};

    // Size of each slice along the split axis. Only build the division subgraph for the
    // uniform-split case; with an explicit "split" attribute it would be a dead layer,
    // since the loop below overwrites sizeSliceAxis on every iteration.
    nvinfer1::ITensor* sizeSliceAxis{nullptr};
    if (outputLengths.empty())
    {
        // sizeSliceAxis = axisLength / numOutputs
        // NOTE(review): with dynamic shapes, divisibility cannot be validated at import time.
        sizeSliceAxis = ctx->network()->addElementWise(
            *gatherDimension(ctx, shape, axis, nvinfer1::Dims{1, 1}),
            *addConstantScalar(ctx, numOutputs, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0),
            nvinfer1::ElementWiseOperation::kDIV
        )->getOutput(0);
    }

    // All-zero starts template; the split-axis entry is overwritten per output.
    nvinfer1::Dims zeroStartsDims{rank};
    std::fill(zeroStartsDims.d, zeroStartsDims.d + zeroStartsDims.nbDims, 0);
    nvinfer1::ITensor* zeroStarts = &makeShapeTensor(ctx, zeroStartsDims);

    // Unit strides in every dimension.
    nvinfer1::Dims strides{rank};
    std::fill(strides.d, strides.d + strides.nbDims, 1);

    std::vector<TensorOrWeights> outputs{};
    for (int i = 0; i < numOutputs; ++i)
    {
        if (!outputLengths.empty())
        {
            sizeSliceAxis = addConstantScalar(ctx, outputLengths.at(i), ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
        }

        // Full starts/sizes shape tensors with the split-axis entry replaced.
        nvinfer1::ITensor* starts{overwriteDim(ctx, zeroStarts, startSliceAxis, axis)};
        nvinfer1::ITensor* sizes{overwriteDim(ctx, shape, sizeSliceAxis, axis)};

        // The static start/size Dims are placeholders; setInput(1)/setInput(2) supply the
        // dynamic values at runtime.
        nvinfer1::ISliceLayer* slice = ctx->network()->addSlice(*tensorPtr, nvinfer1::Dims{rank}, nvinfer1::Dims{rank}, strides);
        slice->setInput(1, *starts);
        slice->setInput(2, *sizes);
        outputs.emplace_back(slice->getOutput(0));

        // Advance the start offset past the slice just emitted.
        startSliceAxis = ctx->network()->addElementWise(*startSliceAxis, *sizeSliceAxis, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
    }
    return outputs;
}
18851896

18861897
DEFINE_BUILTIN_OP_IMPORTER(Sqrt)

onnx2trt_utils.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,12 @@ nvinfer1::ITensor* flattenTensor(IImporterContext* ctx, nvinfer1::ITensor& tenso
709709
return flattenTensorStatic(ctx, tensor, axis);
710710
}
711711

712+
// Extracts one entry from a 1D shape tensor, e.g. gatherDimension(shape=(7, 6, 5), dim=2) yields 5.
// `shape` specifies the shape of the returned tensor and must have a volume of 1.
nvinfer1::ITensor* gatherDimension(IImporterContext* ctx, nvinfer1::ITensor* shapeTensor, int dim, nvinfer1::Dims shape)
{
    // Constant index tensor holding the dimension to pick out.
    nvinfer1::ITensor* indexTensor = addConstantScalar(ctx, dim, ::ONNX_NAMESPACE::TensorProto_DataType_INT32, shape)->getOutput(0);
    // Gather along axis 0 of the shape tensor.
    nvinfer1::IGatherLayer* gather = ctx->network()->addGather(*shapeTensor, *indexTensor, 0);
    return gather->getOutput(0);
}
717+
712718
bool isDynamic(nvinfer1::Dims const& dims)
713719
{
714720
return std::any_of(dims.d, dims.d + dims.nbDims, [](int dim) {return dim == -1;});
@@ -954,6 +960,26 @@ nvinfer1::ITensor& makeShapeTensor(IImporterContext* ctx, nvinfer1::Dims dims)
954960
return convertToTensor(valueWeights, ctx);
955961
}
956962

963+
// Replaces the value at `axis` in a 1D shape tensor with `dim`, leaving all other entries intact.
nvinfer1::ITensor* overwriteDim(IImporterContext* ctx, nvinfer1::ITensor* shape, nvinfer1::ITensor* dim, int axis)
{
    // Number of entries in the (1D) shape tensor.
    const int numDims = shape->getDimensions().d[0];

    std::vector<nvinfer1::ITensor*> pieces{};
    for (int d = 0; d < numDims; ++d)
    {
        // Substitute the overwritten axis; gather every other dimension from the original shape.
        pieces.emplace_back(d == axis ? dim : gatherDimension(ctx, shape, d, nvinfer1::Dims{1, 1}));
    }
    // Concatenate the per-dimension scalars back into a full shape tensor.
    return ctx->network()->addConcatenation(pieces.data(), pieces.size())->getOutput(0);
}
982+
957983
NodeImportResult poolingHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs, nvinfer1::PoolingType type)
958984
{
959985
nvinfer1::ITensor* tensorPtr = &convertToTensor(inputs.at(0), ctx);

onnx2trt_utils.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ NodeImportResult elementwiseHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::Node
167167
// Helper function to flatten a tensor on a specified axis
168168
nvinfer1::ITensor* flattenTensor(IImporterContext* ctx, nvinfer1::ITensor& tensor, int axis);
169169

170+
// Gathers the specified dimension from a shape tensor. e.g. gatherDimension(shape=(7, 6, 5), dim=2) would return 5.
171+
// shape specifies the shape of the returned Tensor. Must have a volume of 1.
172+
nvinfer1::ITensor* gatherDimension(IImporterContext* ctx, nvinfer1::ITensor* shapeTensor, int dim, nvinfer1::Dims shape);
173+
170174
// Helper function to check if any input dimensions are dynamic
171175
bool isDynamic (nvinfer1::Dims const& dims);
172176

@@ -214,6 +218,10 @@ nvinfer1::Dims makeDims(int nbDims, int val);
214218
// Helper function to create a shape tensor from a Dims object for dynamic reshape
215219
nvinfer1::ITensor& makeShapeTensor(IImporterContext* ctx, nvinfer1::Dims dims);
216220

221+
222+
// Helper function to overwrite the value of a single axis in a shape tensor
223+
nvinfer1::ITensor* overwriteDim(IImporterContext* ctx, nvinfer1::ITensor* shape, nvinfer1::ITensor* dim, int axis);
224+
217225
// Helper function to map various ONNX pooling ops into TensorRT.
218226
NodeImportResult poolingHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs, nvinfer1::PoolingType type);
219227

0 commit comments

Comments
 (0)