Skip to content

Commit fe0591c

Browse files
authored
Simplifying logic for supportsModel and adds logic for handling shape layer outputs (#324)
1 parent b7c12a8 commit fe0591c

File tree

1 file changed

+30
-77
lines changed

1 file changed

+30
-77
lines changed

ModelImporter.cpp

Lines changed: 30 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
332332
// Parse the graph and see if we hit any parsing errors
333333
allSupported = parse(serialized_onnx_model, serialized_onnx_model_size);
334334

335-
size_t error_node = std::numeric_limits<size_t>::max();
336-
std::string input_node = "";
335+
int error_node = -1;
336+
std::string input_node{};
337337

338338
if (!allSupported)
339339
{
@@ -343,7 +343,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
343343
nvonnxparser::IParserError const* error = getError(i);
344344
if (error->node() != -1)
345345
{
346-
cout << "Found unsupported node: " << error->node() << endl;
347346
error_node = error->node();
348347
allSupported = false;
349348
}
@@ -354,7 +353,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
354353
// Node name is extracted through error->file as all errors thrown on input nodes are wrapped
355354
// around MAKE_INPUT_ERROR.
356355
input_node = error->file();
357-
cout << "Found unsupported input: " << input_node << endl;
358356
}
359357
}
360358
}
@@ -369,18 +367,19 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
369367
{
370368
::ONNX_NAMESPACE::NodeProto const& node = model.graph().node(node_idx);
371369
// Check for connecting nodes to faulty input nodes and mark them as unsupported
372-
bool contains_input = (input_node == "") ? false : check_for_input(node, input_node);
373-
if (this->supportsOperator(node.op_type().c_str()) && !contains_input)
370+
bool contains_input = (input_node.empty()) ? false : check_for_input(node, input_node);
371+
bool contains_index = node_idx == error_node;
372+
if (this->supportsOperator(node.op_type().c_str()) && !contains_input && !contains_index)
374373
{
375374
if (newSubGraph)
376375
{
377376
// If it is the beginning of a new subGraph, we start a new vector
378377
sub_graph_collection.emplace_back();
379-
// Mark all new graphs as "unknown"
378+
// Mark all new graphs as "unsupported"
380379
sub_graph_collection.back().second = false;
381380
newSubGraph = false;
382381
}
383-
// We add the new node to the last graph
382+
// Add supported nodes to the subgraph at the back.
384383
sub_graph_collection.back().first.emplace_back(node_idx);
385384
}
386385
else
@@ -391,75 +390,6 @@ bool ModelImporter::supportsModel(void const *serialized_onnx_model,
391390
}
392391
}
393392

394-
if (!allSupported)
395-
{
396-
// We hit some errors when parsing. Iterate through them to find the failing node.
397-
int nerror = getNbErrors();
398-
for (int i = 0; i < nerror; ++i)
399-
{
400-
nvonnxparser::IParserError const* error = getError(i);
401-
if (error->node() != -1)
402-
{
403-
error_node = error->node();
404-
allSupported = false;
405-
}
406-
// The node that we failed on is one of the input nodes (-1). Since TRT cannot parse the
407-
// inputs return false.
408-
else
409-
{
410-
return allSupported;
411-
}
412-
}
413-
// Update the subgraph collection.
414-
for (size_t graph_index = 0; graph_index < sub_graph_collection.size(); graph_index++)
415-
{
416-
NodesContainer_t subgraph = sub_graph_collection[graph_index].first;
417-
418-
// If we've already iterated past the error_node, all future graphs are unknown, so break
419-
if (subgraph[0] > error_node)
420-
{
421-
break;
422-
}
423-
for (size_t node_index = 0; node_index < subgraph.size(); node_index++)
424-
{
425-
// Split the graph at the node we hit an assertion at when parsing.
426-
if (subgraph[node_index] == error_node)
427-
{
428-
// Case where subgraph has only one node and it's unsupported, simply delete it.
429-
if (node_index == 0 && subgraph.size() == 1)
430-
{
431-
sub_graph_collection.erase(sub_graph_collection.begin() + graph_index);
432-
}
433-
// Case where subgraph has more than one node and the first node is unsupported. No "split_before" graph.
434-
// The split_after graph is marked as unsupported.
435-
else if (node_index == 0)
436-
{
437-
NodesContainer_t split_after (subgraph.begin() + node_index + 1, subgraph.end());
438-
sub_graph_collection[graph_index].first = split_after;
439-
sub_graph_collection[graph_index].second = false;
440-
}
441-
// Case where subgraph has more than one node and the last node is unsupported. No "split_after" graph.
442-
// Note due to potential shape tensor inputs, cannot mark the first subgraph as supported here.
443-
else if (node_index == subgraph.size() - 1)
444-
{
445-
NodesContainer_t split_before (subgraph.begin(), subgraph.begin() + node_index);
446-
sub_graph_collection[graph_index].first = split_before;
447-
}
448-
// Case where unsupported node is somewhere in the middle. Split the subgraph at that point into two.
449-
// Note due to potential shape tensor inputs, cannot mark the first subgraph as supported here.
450-
else
451-
{
452-
NodesContainer_t split_before (subgraph.begin(), subgraph.begin() + node_index);
453-
NodesContainer_t split_after (subgraph.begin() + node_index + 1, subgraph.end());
454-
sub_graph_collection[graph_index].first = split_before;
455-
sub_graph_collection.insert(sub_graph_collection.begin() + graph_index + 1, std::make_pair(split_after, false));
456-
}
457-
break;
458-
}
459-
}
460-
}
461-
}
462-
463393
// Only mark the subgraph as supported if there is one supported subgraph.
464394
if (allSupported)
465395
{
@@ -565,7 +495,10 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
565495

566496
if (node.op_type() == "Shape")
567497
{
498+
// Insert the node itself to catch ShapeLayer outputs.
568499
shapeTensors.insert({node.name(), node_idx});
500+
// Shape layers should only have one output
501+
shapeTensors.insert({node.output(0), node_idx});
569502
}
570503

571504
for( size_t i=0; i<outputs.size(); ++i ) {
@@ -602,6 +535,7 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
602535
}
603536
nvinfer1::ITensor** user_output = _importer_ctx.getUserOutput(output.name().c_str());
604537
if( !user_output ) {
538+
// Sanity check that the output is not the output of a shape layer
605539
auto outputName = output_tensor_ptr->getName();
606540
if (shapeTensors.count(outputName))
607541
{
@@ -630,6 +564,25 @@ ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const &model,
630564
ASSERT(user_output.is_tensor(), ErrorCode::kINVALID_VALUE);
631565
*user_output_ptr = &user_output.tensor();
632566
}
567+
568+
// Do a sanity check that all ShapeLayer outputs are being used as shape tensors.
569+
auto* network = _importer_ctx.network();
570+
571+
for (int i = 0; i < network->getNbLayers(); i++)
572+
{
573+
auto* layer = network->getLayer(i);
574+
for (int j = 0; j < layer->getNbInputs(); j++)
575+
{
576+
auto* tensor = layer->getInput(j);
577+
auto tensorName = tensor->getName();
578+
if (shapeTensors.count(tensorName) && !tensor->isShapeTensor())
579+
{
580+
_current_node = shapeTensors.at(tensorName);
581+
ASSERT(false && "Shape layer outputs must be used as a shape tensor!" , ErrorCode::kUNSUPPORTED_GRAPH);
582+
}
583+
}
584+
}
585+
633586
return Status::success();
634587
}
635588

0 commit comments

Comments
 (0)