@@ -38,9 +38,8 @@ static ProtobufShutter protobufShutter;
3838void setTensorLocations (
3939 ImporterContext* ctx, std::vector<std::string> const & tensors, std::vector<std::string> const & locations)
4040{
41- ONNXTRT_CHECK ((tensors.size () >= locations.size ())
42- && " The size of tensors misaligns with the size of the attribute trt_outputs_loc." ,
43- nvonnxparser::ErrorCode::kINVALID_GRAPH );
41+ ONNXTRT_CHECK (tensors.size () >= locations.size (),
42+ " The size of tensors misaligns with the size of the attribute trt_outputs_loc." , ErrorCode::kINVALID_GRAPH );
4443 for (size_t i = 0 ; i < locations.size (); ++i)
4544 {
4645 std::string tensor = tensors.at (i);
@@ -50,8 +49,8 @@ void setTensorLocations(
5049
5150 if (ctx->tensorLocations ().count (tensor) > 0 )
5251 {
53- ONNXTRT_CHECK (( ctx->tensorLocations ()[tensor] == loc) && " The tensor location cannot be changed." ,
54- nvonnxparser:: ErrorCode::kINVALID_GRAPH );
52+ ONNXTRT_CHECK (ctx->tensorLocations ()[tensor] == loc, " The tensor location cannot be changed." ,
53+ ErrorCode::kINVALID_GRAPH );
5554 }
5655 else
5756 {
@@ -65,16 +64,19 @@ template <typename T>
6564void setStringMap (
6665 ImporterContext* ctx, std::vector<std::string> const & tensors, std::vector<T> const & data, StringMap<T>& map)
6766{
68- ONNXTRT_CHECK (( tensors.size () >= data.size ())
69- && " The size of tensors misaligns with the size of the attribute trt_outputs_range_min/max." ,
70- nvonnxparser:: ErrorCode::kINVALID_GRAPH );
67+ ONNXTRT_CHECK (tensors.size () >= data.size (),
68+ " The size of tensors misaligns with the size of the attribute trt_outputs_range_min/max." ,
69+ ErrorCode::kINVALID_GRAPH );
7170 for (size_t i = 0 ; i < data.size (); ++i)
7271 {
7372 std::string name = tensors.at (i);
7473 T dataName = data.at (i);
7574 if (map.count (name) > 0 )
7675 {
77- ONNXTRT_CHECK ( (map[name] == dataName) && " The order of tensorRangeMin/Max in context misaligns with the order of the attribute trt_outputs_range_min/max." , nvonnxparser::ErrorCode::kINVALID_GRAPH );
76+ ONNXTRT_CHECK (map[name] == dataName,
77+ " The order of tensorRangeMin/Max in context misaligns with the order of the attribute "
78+ " trt_outputs_range_min/max." ,
79+ ErrorCode::kINVALID_GRAPH );
7880 }
7981 else
8082 {
@@ -163,7 +165,14 @@ void parseNode(
163165 LOG_VERBOSE (ssInputs.str ());
164166
165167 // UINT8 weights that are not Q/DQ inputs will be converted to INT32
166- if (node.op_type () != " QuantizeLinear" && node.op_type () != " DequantizeLinear" )
168+ // If the UINT8 quantization flag is enabled, constants with UINT8 will also be permitted.
169+ uint32_t uint8AsymmetricQuantizationFlag = 1U
170+ << static_cast <uint32_t >(nvonnxparser::OnnxParserFlag::kENABLE_UINT8_AND_ASYMMETRIC_QUANTIZATION_DLA );
171+ bool allowUint8Quantization = ctx->getFlags () & uint8AsymmetricQuantizationFlag;
172+
173+ bool skipUInt8Conversion = (node.op_type () == " QuantizeLinear" || node.op_type () == " DequantizeLinear"
174+ || (allowUint8Quantization && node.op_type () == " Constant" ));
175+ if (!skipUInt8Conversion)
167176 {
168177 for (auto & nodeInput : nodeInputs)
169178 {
@@ -289,20 +298,26 @@ void parseNode(
289298 {
290299 ctx->registerTensor (std::move (output), outputName);
291300 }
292- // UINT8 is only allowed as network inputs, network outputs, and constants for QDQ nodes. Therefore any
293- // non-constant node that produces an UINT8-typed output that is not also a graph output is unsupported.
294- if (output.getType () == " UINT8" && node.op_type () != " Constant" )
301+ // UINT8 is only allowed as network inputs, network outputs, and constants for QDQ nodes unless the UINT8
302+ // quantization flag is set. If the UINT8 quantization flag is set, then UINT8 is also permitted as a
303+ // QuantizeLinear output or Gather output (when they feed into a dequantize node). Other than the cases listed,
304+ // any non-constant node that produces an UINT8-typed output that is not also a graph output is unsupported.
305+ if (output.getType () == " UINT8" )
295306 {
296- bool legalUINT8 = false ;
307+ bool legalUINT8 = node.op_type () == " Constant"
308+ || (allowUint8Quantization && (node.op_type () == " Gather" || node.op_type () == " QuantizeLinear" ));
297309 for (auto const & graphOutput : ctx->getGraphOutputNames ())
298310 {
299311 if (graphOutput.name () == outputName)
300312 {
301313 legalUINT8 = true ;
314+ break ;
302315 }
303316 }
304- ONNXTRT_CHECK_NODE (legalUINT8, " TensorRT does not support UINT8 types for intermediate tensors!" , node,
305- nodeIdx, ErrorCode::kUNSUPPORTED_NODE );
317+ ONNXTRT_CHECK_NODE (legalUINT8,
318+ " TensorRT does not support UINT8 types for intermediate tensors. For UINT8 quantization, the "
319+ " kIMPORT_UINT8_QUANTIZATION flag must be set. (DLA version >= 3.16 only)" ,
320+ node, nodeIdx, ErrorCode::kUNSUPPORTED_NODE );
306321 }
307322 trtCnt++;
308323 }
@@ -366,9 +381,8 @@ void parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& graph,
366381 {
367382 LOG_VERBOSE (" Importing initializer: " << initializer.name ());
368383 ShapedWeights weights;
369- ONNXTRT_CHECK (
370- ctx->getWeightsContext ().convertOnnxWeights (initializer, &weights) && " Failed to import initializer." ,
371- ErrorCode::kUNSUPPORTED_NODE );
384+ ONNXTRT_CHECK (ctx->getWeightsContext ().convertOnnxWeights (initializer, &weights),
385+ " Failed to import initializer: " << initializer.name (), ErrorCode::kUNSUPPORTED_NODE );
372386 ctx->registerTensor (TensorOrWeights{std::move (weights)}, initializer.name ());
373387 }
374388 }
@@ -385,7 +399,7 @@ void parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& graph,
385399
386400 std::vector<size_t > topoOrder;
387401 ONNXTRT_CHECK (
388- toposort (graph.node (), &topoOrder) && " Failed to sort the model topologically." , ErrorCode::kINVALID_GRAPH );
402+ toposort (graph.node (), &topoOrder), " Failed to sort the model topologically." , ErrorCode::kINVALID_GRAPH );
389403
390404 for (auto const & nodeIndex : topoOrder)
391405 {
@@ -682,7 +696,7 @@ bool ModelImporter::isSubgraphSupported(int64_t const index) noexcept
682696 errorMessage << " Query index " << index
683697 << " exceeds subgraph support vector (size = " << mSubGraphSupportVector .size ()
684698 << " ). Have you called supports_model_v2?" ;
685- ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index) && errorMessage.str (). c_str (),
699+ ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index), errorMessage.str (),
686700 ErrorCode::kINVALID_VALUE );
687701 return mSubGraphSupportVector [index].second ;
688702 }
@@ -698,7 +712,7 @@ int64_t* ModelImporter::getSubgraphNodes(int64_t const index, int64_t& subgraphL
698712 errorMessage << " Query index " << index
699713 << " exceeds subgraph support vector (size = " << mSubGraphSupportVector .size ()
700714 << " ). Have you called supports_model_v2?" ;
701- ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index) && errorMessage.str (). c_str (),
715+ ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index), errorMessage.str (),
702716 ErrorCode::kINVALID_VALUE );
703717 subgraphLength = mSubGraphSupportVector [index].first .size ();
704718 return mSubGraphSupportVector [index].first .data ();
@@ -769,8 +783,8 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
769783 mImporterCtx .clearOpsets ();
770784 // Add domain import limit for security reasons
771785 int32_t const MAX_DOMAINS = 1024 ;
772- ONNXTRT_CHECK (model.opset_import ().size () <= MAX_DOMAINS
773- && " Model contains more than 1024 domains! Parsing will halt for security reasons." ,
786+ ONNXTRT_CHECK (model.opset_import ().size () <= MAX_DOMAINS,
787+ " Model contains more than 1024 domains! Parsing will halt for security reasons." ,
774788 ErrorCode::kUNSUPPORTED_GRAPH );
775789 for (int32_t i = 0 ; i < model.opset_import ().size (); ++i)
776790 {
@@ -808,8 +822,8 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
808822 // Mark outputs defined in the ONNX model (unless tensors are user-requested)
809823 for (::ONNX_NAMESPACE::ValueInfoProto const & output : graph.output ())
810824 {
811- ONNXTRT_CHECK ((mImporterCtx .tensors ().count (output.name ())) && " The output tensor was not registered. " ,
812- ErrorCode::kINVALID_GRAPH );
825+ ONNXTRT_CHECK ((mImporterCtx .tensors ().count (output.name ())),
826+ " The output tensor " << output. name () << " was not registered. " , ErrorCode::kINVALID_GRAPH );
813827 nvinfer1::ITensor* output_tensor_ptr
814828 = &convertToTensor (mImporterCtx .tensors ().at (output.name ()), &mImporterCtx );
815829 LOG_VERBOSE (" Marking " << output_tensor_ptr->getName () << " as output: " << output.name ());
@@ -821,21 +835,19 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
821835 // TODO: Does this break things by changing the name of the input tensor?
822836 output_tensor_ptr->setName ((" __" + output.name ()).c_str ());
823837 output_tensor_ptr = &identity (&mImporterCtx , output_tensor_ptr).tensor ();
824- ONNXTRT_CHECK (output_tensor_ptr && " Failed to add an Identity layer." , ErrorCode::kUNSUPPORTED_NODE );
838+ ONNXTRT_CHECK (output_tensor_ptr, " Failed to add an Identity layer." , ErrorCode::kUNSUPPORTED_NODE );
825839 output_tensor_ptr->setName (output.name ().c_str ());
826840 }
827841
828842 mImporterCtx .network ()->markOutput (*output_tensor_ptr);
829843 nvinfer1::DataType output_trt_dtype;
830844
831- ONNXTRT_CHECK (convertDtype (output.type ().tensor_type ().elem_type (), &output_trt_dtype)
832- && " Failed to convert ONNX date type to TensorRT data type." ,
833- ErrorCode::kUNSUPPORTED_NODE );
845+ ONNXTRT_CHECK (convertDtype (output.type ().tensor_type ().elem_type (), &output_trt_dtype),
846+ " Failed to convert ONNX date type to TensorRT data type." , ErrorCode::kUNSUPPORTED_NODE );
834847 // For INT32 data type, output type must match tensor type
835848 ONNXTRT_CHECK ((output_tensor_ptr->getType () != nvinfer1::DataType::kINT32
836- || output_trt_dtype == nvinfer1::DataType::kINT32 )
837- && " For INT32 tensors, the output type must also be INT32." ,
838- ErrorCode::kUNSUPPORTED_NODE );
849+ || output_trt_dtype == nvinfer1::DataType::kINT32 ),
850+ " For INT32 tensors, the output type must also be INT32." , ErrorCode::kUNSUPPORTED_NODE );
839851 // Note: Without this, output type is always float32
840852 output_tensor_ptr->setType (output_trt_dtype);
841853 if (output_trt_dtype == nvinfer1::DataType::kINT64 )
@@ -890,15 +902,15 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
890902 // Set locations for all tensors
891903 for (auto const & tensor : ctx->tensorLocations ())
892904 {
893- ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ) && " The tensor does not have an assigned location." ,
905+ ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ), " The tensor does not have an assigned location." ,
894906 nvonnxparser::ErrorCode::kINVALID_GRAPH );
895907 tensors.at (tensor.first )->setLocation (tensor.second );
896908 }
897909 // Set dynamic range for all tensors
898910 for (auto const & tensor : ctx->tensorRangeMins ())
899911 {
900912 // if there's a min range, there must be a max range as well
901- ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ) && " The tensor does not have an assigned location ." ,
913+ ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ), " The tensor does not have its dynamic range set ." ,
902914 nvonnxparser::ErrorCode::kINVALID_GRAPH );
903915 if (!std::isnan (tensor.second ))
904916 {
@@ -911,7 +923,7 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
911923 // Set precisions for all layers.
912924 for (auto const & layer : ctx->layerPrecisions ())
913925 {
914- ONNXTRT_CHECK ((layers.count (layer.first ) > 0 ) && " The layer does not have an assigned precision." ,
926+ ONNXTRT_CHECK ((layers.count (layer.first ) > 0 ), " The layer does not have an assigned precision." ,
915927 nvonnxparser::ErrorCode::kINVALID_GRAPH );
916928 layers.at (layer.first )->setPrecision (layer.second );
917929 }
@@ -932,6 +944,7 @@ bool ModelImporter::parseFromFile(char const* onnxModelFile, int32_t verbosity)
932944{
933945 ONNXTRT_TRY
934946 {
947+ ONNXTRT_CHECK (onnxModelFile, " Input file cannot be empty." , ErrorCode::kINVALID_VALUE );
935948 auto * ctx = &mImporterCtx ;
936949
937950 // Define S_ISREG macro for Windows
@@ -940,23 +953,16 @@ bool ModelImporter::parseFromFile(char const* onnxModelFile, int32_t verbosity)
940953#endif
941954
942955 struct stat sb;
943- if (stat (onnxModelFile, &sb) == 0 && !S_ISREG (sb.st_mode ))
944- {
945- LOG_ERROR (" Input is not a regular file: " << onnxModelFile);
946- return false ;
947- }
956+ ONNXTRT_CHECK (stat (onnxModelFile, &sb) == 0 && S_ISREG (sb.st_mode ),
957+ " Input file cannot be found, or is not a regular file: " << onnxModelFile, ErrorCode::kINVALID_VALUE );
948958
949959 GOOGLE_PROTOBUF_VERIFY_VERSION;
950960
951961 // Own the ONNX model for weights to persist.
952962 mONNXModels .emplace_back ();
953963 ::ONNX_NAMESPACE::ModelProto& onnxModel = mONNXModels .back ();
954- bool const fileLoadSuccess = ParseFromFileAsBinary (&onnxModel, onnxModelFile);
955- if (!fileLoadSuccess)
956- {
957- LOG_ERROR (" Failed to parse ONNX model from file: " << onnxModelFile << " !" );
958- return false ;
959- }
964+ ONNXTRT_CHECK (ParseFromFileAsBinary (&onnxModel, onnxModelFile),
965+ " Cannot read from input file: " << onnxModelFile, ErrorCode::kINVALID_VALUE );
960966
961967 // Keep track of the absolute path to the ONNX file.
962968 mImporterCtx .setOnnxFileLocation (onnxModelFile);
0 commit comments