@@ -1839,48 +1839,59 @@ DEFINE_BUILTIN_OP_IMPORTER(SpatialBN)
18391839
18401840DEFINE_BUILTIN_OP_IMPORTER (Split)
18411841{
1842- ASSERT (inputs.size () == 1 , ErrorCode::kUNSUPPORTED_NODE );
1843- nvinfer1::ITensor* tensor_ptr = &convertToTensor (inputs.at (0 ), ctx);
1844- nvinfer1::Dims dims = tensor_ptr->getDimensions ();
1845- int nbDims = dims.nbDims ;
1842+ const int numOutputs = node.output ().size ();
1843+
1844+ nvinfer1::ITensor* tensorPtr = &convertToTensor (inputs.at (0 ), ctx);
1845+ const int rank = tensorPtr->getDimensions ().nbDims ;
1846+ nvinfer1::ITensor* shape = ctx->network ()->addShape (*tensorPtr)->getOutput (0 );
1847+
18461848 OnnxAttrs attrs (node);
18471849 int axis = attrs.get <int >(" axis" , 0 );
1848- TRT_CHECK (convert_axis (axis, nbDims));
1849- std::vector<int > output_lengths;
1850- int noutput = node.output ().size ();
1851- std::vector<int > start_index (noutput, 0 );
1850+ TRT_CHECK (convert_axis (axis, rank));
1851+
1852+ std::vector<int > outputLengths;
18521853 if (attrs.count (" split" ))
18531854 {
1854- output_lengths = attrs.get <std::vector<int >>(" split" );
1855- ASSERT (static_cast <int >(output_lengths.size ()) == noutput, ErrorCode::kINVALID_NODE );
1856- }
1857- else
1858- {
1859- ASSERT (dims.d [axis] == -1 || dims.d [axis] % noutput == 0 , ErrorCode::kINVALID_NODE );
1860- output_lengths.assign (noutput, dims.d [axis] / noutput);
1861- }
1862- for (size_t i = 1 ; i < output_lengths.size (); i++)
1863- {
1864- start_index[i] = start_index[i - 1 ] + output_lengths[i - 1 ];
1855+ outputLengths = attrs.get <std::vector<int >>(" split" );
1856+ ASSERT (static_cast <int >(outputLengths.size ()) == numOutputs, ErrorCode::kINVALID_NODE );
18651857 }
18661858
1867- nvinfer1::Dims sliceStart = makeDims (nbDims, 0 );
1868- nvinfer1::Dims sliceSize = dims;
1869- nvinfer1::Dims sliceStride = makeDims (nbDims, 1 );
1870- std::vector<TensorOrWeights> outputs;
1871- for (int i = 0 ; i < noutput; ++i)
1872- {
1873- sliceStart.d [axis] = start_index[i];
1874- sliceSize.d [axis] = output_lengths[i];
1875- auto const layer = ctx->network ()->addSlice (*tensor_ptr, sliceStart, sliceSize, sliceStride);
1876- if (std::any_of (sliceSize.d , sliceSize.d + sliceSize.nbDims , [](int i){return i == -1 ;})){
1877- layer->setInput (1 , dimension_to_tensor (ctx, sliceStart));
1878- layer->setInput (2 , dimension_to_tensor (ctx, sliceSize));
1879- layer->setInput (3 , dimension_to_tensor (ctx, sliceStride));
1859+ nvinfer1::ITensor* startSliceAxis{addConstantScalar<int32_t >(ctx, 0 , ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1 , 1 })->getOutput (0 )};
1860+ // sizeSliceAxis = axisLength / numOutputs
1861+ nvinfer1::ITensor* sizeSliceAxis{ctx->network ()->addElementWise (
1862+ *gatherDimension (ctx, shape, axis, nvinfer1::Dims{1 , 1 }),
1863+ *addConstantScalar (ctx, numOutputs, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1 , 1 })->getOutput (0 ),
1864+ nvinfer1::ElementWiseOperation::kDIV
1865+ )->getOutput (0 )};
1866+
1867+ nvinfer1::Dims zeroStartsDims{rank};
1868+ std::fill (zeroStartsDims.d , zeroStartsDims.d + zeroStartsDims.nbDims , 0 );
1869+ nvinfer1::ITensor* zeroStarts = &makeShapeTensor (ctx, zeroStartsDims);
1870+
1871+ nvinfer1::Dims strides{rank};
1872+ std::fill (strides.d , strides.d + strides.nbDims , 1 );
1873+
1874+ std::vector<TensorOrWeights> outputs{};
1875+ for (int i = 0 ; i < numOutputs; ++i)
1876+ {
1877+ if (!outputLengths.empty ())
1878+ {
1879+ sizeSliceAxis = addConstantScalar (ctx, outputLengths.at (i), ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1 , 1 })->getOutput (0 );
18801880 }
1881- outputs.push_back (layer->getOutput (0 ));
1881+
1882+ nvinfer1::ITensor* starts{overwriteDim (ctx, zeroStarts, startSliceAxis, axis)};
1883+ nvinfer1::ITensor* sizes{overwriteDim (ctx, shape, sizeSliceAxis, axis)};
1884+
1885+ nvinfer1::ISliceLayer* slice = ctx->network ()->addSlice (*tensorPtr, nvinfer1::Dims{rank}, nvinfer1::Dims{rank}, strides);
1886+ slice->setInput (1 , *starts);
1887+ slice->setInput (2 , *sizes);
1888+ outputs.emplace_back (slice->getOutput (0 ));
1889+
1890+ startSliceAxis = ctx->network ()->addElementWise (*startSliceAxis, *sizeSliceAxis, nvinfer1::ElementWiseOperation::kSUM )->getOutput (0 );
18821891 }
1892+
18831893 return outputs;
1894+
18841895}
18851896
18861897DEFINE_BUILTIN_OP_IMPORTER (Sqrt)
0 commit comments