@@ -332,8 +332,8 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
332332 // Parse the graph and see if we hit any parsing errors
333333 allSupported = parse (serialized_onnx_model, serialized_onnx_model_size);
334334
335-    size_t error_node = std::numeric_limits<size_t>::max();
336-    std::string input_node = "";
335+    int error_node = -1;
336+    std::string input_node{};
337337
338338 if (!allSupported)
339339 {
@@ -343,7 +343,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
343343 nvonnxparser::IParserError const * error = getError (i);
344344 if (error->node () != -1 )
345345 {
346-            cout << "Found unsupported node: " << error->node() << endl;
347346 error_node = error->node ();
348347 allSupported = false ;
349348 }
@@ -354,7 +353,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
354353 // Node name is extracted through error->file as all errors thrown on input nodes are wrapped
355354 // around MAKE_INPUT_ERROR.
356355 input_node = error->file ();
357-            cout << "Found unsupported input: " << input_node << endl;
358356 }
359357 }
360358 }
@@ -369,18 +367,19 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
369367 {
370368 ::ONNX_NAMESPACE::NodeProto const & node = model.graph ().node (node_idx);
371369 // Check for connecting nodes to faulty input nodes and mark them as unsupported
372-        bool contains_input = (input_node == "") ? false : check_for_input(node, input_node);
373-        if (this->supportsOperator(node.op_type().c_str()) && !contains_input)
370+        bool contains_input = (input_node.empty()) ? false : check_for_input(node, input_node);
371+        bool contains_index = node_idx == error_node;
372+        if (this->supportsOperator(node.op_type().c_str()) && !contains_input && !contains_index)
374373 {
375374 if (newSubGraph)
376375 {
377376 // If it is the beginning of a new subGraph, we start a new vector
378377 sub_graph_collection.emplace_back ();
379- // Mark all new graphs as "unknown "
378+ // Mark all new graphs as "unsupported "
380379 sub_graph_collection.back ().second = false ;
381380 newSubGraph = false ;
382381 }
383- // We add the new node to the last graph
382+ // Add supported nodes to the subgraph at the back.
384383 sub_graph_collection.back ().first .emplace_back (node_idx);
385384 }
386385 else
@@ -391,75 +390,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
391390 }
392391 }
393392
394- if (!allSupported)
395- {
396- // We hit some errors when parsing. Iterate through them to find the failing node.
397- int nerror = getNbErrors ();
398- for (int i = 0 ; i < nerror; ++i)
399- {
400- nvonnxparser::IParserError const * error = getError (i);
401- if (error->node () != -1 )
402- {
403- error_node = error->node ();
404- allSupported = false ;
405- }
406- // The node that we failed on is one of the input nodes (-1). Since TRT cannot parse the
407- // inputs return false.
408- else
409- {
410- return allSupported;
411- }
412- }
413- // Update the subgraph collection.
414- for (size_t graph_index = 0 ; graph_index < sub_graph_collection.size (); graph_index++)
415- {
416- NodesContainer_t subgraph = sub_graph_collection[graph_index].first ;
417-
418- // If we've already iterated past the error_node, all future graphs are unknown, so break
419- if (subgraph[0 ] > error_node)
420- {
421- break ;
422- }
423- for (size_t node_index = 0 ; node_index < subgraph.size (); node_index++)
424- {
425- // Split the graph at the node we hit an assertion at when parsing.
426- if (subgraph[node_index] == error_node)
427- {
428- // Case where subgraph has only one node and it's unsupported, simply delete it.
429- if (node_index == 0 && subgraph.size () == 1 )
430- {
431- sub_graph_collection.erase (sub_graph_collection.begin () + graph_index);
432- }
433- // Case where subgraph has more than one node and the first node is unsupported. No "split_before" graph.
434- // The split_after graph is marked as unsupported.
435- else if (node_index == 0 )
436- {
437- NodesContainer_t split_after (subgraph.begin () + node_index + 1 , subgraph.end ());
438- sub_graph_collection[graph_index].first = split_after;
439- sub_graph_collection[graph_index].second = false ;
440- }
441- // Case where subgraph has more than one node and the last node is unsupported. No "split_after" graph.
442- // Note due to potential shape tensor inputs, cannot mark the first subgraph as supported here.
443- else if (node_index == subgraph.size () - 1 )
444- {
445- NodesContainer_t split_before (subgraph.begin (), subgraph.begin () + node_index);
446- sub_graph_collection[graph_index].first = split_before;
447- }
448- // Case where unsupported node is somewhere in the middle. Split the subgraph at that point into two.
449- // Note due to potential shape tensor inputs, cannot mark the first subgraph as supported here.
450- else
451- {
452- NodesContainer_t split_before (subgraph.begin (), subgraph.begin () + node_index);
453- NodesContainer_t split_after (subgraph.begin () + node_index + 1 , subgraph.end ());
454- sub_graph_collection[graph_index].first = split_before;
455- sub_graph_collection.insert (sub_graph_collection.begin () + graph_index + 1 , std::make_pair (split_after, false ));
456- }
457- break ;
458- }
459- }
460- }
461- }
462-
463393 // Only mark the subgraph as supported if there is one supported subgraph.
464394 if (allSupported)
465395 {
@@ -565,7 +495,10 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
565495
566496 if (node.op_type () == " Shape" )
567497 {
498+ // Insert the node itself to catch ShapeLayer outputs.
568499 shapeTensors.insert ({node.name (), node_idx});
500+ // Shape layers should only have one output
501+ shapeTensors.insert ({node.output (0 ), node_idx});
569502 }
570503
571504 for ( size_t i=0 ; i<outputs.size (); ++i ) {
@@ -602,6 +535,7 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
602535 }
603536 nvinfer1::ITensor** user_output = _importer_ctx.getUserOutput (output.name ().c_str ());
604537 if ( !user_output ) {
538+ // Sanity check that the output is not the output of a shape layer
605539 auto outputName = output_tensor_ptr->getName ();
606540 if (shapeTensors.count (outputName))
607541 {
@@ -630,6 +564,25 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
630564 ASSERT (user_output.is_tensor (), ErrorCode::kINVALID_VALUE );
631565 *user_output_ptr = &user_output.tensor ();
632566 }
567+
568+ // Do a sanity check that all ShapeLayer outputs are being used as shape tensors.
569+ auto * network = _importer_ctx.network ();
570+
571+ for (int i = 0 ; i < network->getNbLayers (); i++)
572+ {
573+ auto * layer = network->getLayer (i);
574+ for (int j = 0 ; j < layer->getNbInputs (); j++)
575+ {
576+ auto * tensor = layer->getInput (j);
577+ auto tensorName = tensor->getName ();
578+ if (shapeTensors.count (tensorName) && !tensor->isShapeTensor ())
579+ {
580+ _current_node = shapeTensors.at (tensorName);
581+            ASSERT(false && "Shape layer outputs must be used as a shape tensor!", ErrorCode::kUNSUPPORTED_GRAPH);
582+ }
583+ }
584+ }
585+
633586 return Status::success ();
634587}
635588