@@ -376,29 +376,30 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
376
376
std::vector<int > origin_outputs_dtype;
377
377
std::map<std::string, int > map_origin_outputs_dtype;
378
378
379
- // Whether to mark Outpus
379
+ // Mark TensorRT output nodes as trt outputs
380
380
auto mark_output = Get<bool >(" mark_output" );
381
381
auto output_tensor_name =
382
382
Get<std::vector<std::string>>(" output_tensor_names" );
383
- VLOG ( 1 ) << " mark Output: " << mark_output ;
383
+ auto mark_output_with_id = Get< bool >( " mark_output_with_id " ) ;
384
384
385
- if (mark_output == 1 ) {
385
+ if (mark_output) {
386
386
VLOG (1 ) << " begin to mark output ..." ;
387
387
for (auto node : subgraph) {
388
388
if (node->NodeType () == Node::Type::kOperation ) {
389
- if (node->Op ()->Outputs ().count (" Xshape" )) continue ;
390
389
for (auto *x : node->outputs ) {
391
390
if (std::count (parameters.begin (), parameters.end (), x->Name ()) > 0 )
392
391
continue ;
393
- if (!output_tensor_name.empty () &&
394
- std::count (output_tensor_name.begin (),
395
- output_tensor_name.end (),
396
- x->Name ())) {
397
- VLOG (1 ) << " output " << x->Name () << " has been marked" ;
398
- std::string output_name_withid =
399
- x->Name () + std::to_string (x->id ());
392
+ std::string name_with_id = x->Name () + std::to_string (x->id ());
393
+ if (((!mark_output_with_id && std::count (output_tensor_name.begin (),
394
+ output_tensor_name.end (),
395
+ x->Name ()) > 0 ) ||
396
+ (mark_output_with_id && std::count (output_tensor_name.begin (),
397
+ output_tensor_name.end (),
398
+ name_with_id) > 0 )) &&
399
+ !x->outputs .empty ()) {
400
+ VLOG (3 ) << " output " << x->Name () << " has been marked" ;
400
401
output_names.insert (x->Name ());
401
- output_names_with_id.insert (output_name_withid );
402
+ output_names_with_id.insert (name_with_id );
402
403
origin_name_output_rank[x->Name ()] = x->Var ()->GetShape ().size ();
403
404
trt_outputs.insert (x);
404
405
map_origin_outputs_dtype[x->Name ()] =
0 commit comments