38
38
string_to_tuple ,
39
39
get_or_parse_value ,
40
40
parse_partial_shape ,
41
+ postprocess_output_name
41
42
)
42
43
from .launcher import Launcher
43
44
from ..logging import print_info
@@ -330,7 +331,7 @@ def _reshape_input(self, shapes, make_dynamic=False):
330
331
p_shape = PartialShape (
331
332
[Dimension (d ) if not isinstance (d , tuple ) else Dimension (d [0 ], d [1 ]) for d in shape ]
332
333
)
333
- partial_shapes [self .input_to_tensor_name [name ]] = p_shape
334
+ partial_shapes [self .input_to_index [name ]] = p_shape
334
335
335
336
self .network .reshape (partial_shapes )
336
337
self .dyn_input_layers , self ._partial_shapes = self .get_dynamic_inputs (self .network )
@@ -548,6 +549,7 @@ def load_network(self, network=None, log=False, preprocessing=None):
548
549
if self .network is not None :
549
550
self .dyn_input_layers , self ._partial_shapes = self .get_dynamic_inputs (self .network )
550
551
self .input_to_tensor_name = self .get_input_tensor_name_mapping (self .network )
552
+ self .input_to_index = {inp .get_node ().friendly_name : idx for idx , inp in enumerate (self .network .inputs )}
551
553
552
554
if not self ._postpone_input_configuration :
553
555
self ._set_precision ()
@@ -813,8 +815,8 @@ def _fill_lstm_inputs(self, infer_outputs=None):
813
815
feed_dict = {}
814
816
for lstm_var , output_layer in self ._lstm_inputs .items ():
815
817
layer_shape = parse_partial_shape (self .inputs [lstm_var ].partial_shape )
816
- if infer_outputs and output_layer not in infer_outputs :
817
- raise ConfigError ( 'Output node with name {} not found' . format (output_layer ) )
818
+ if infer_outputs :
819
+ output_layer = postprocess_output_name (output_layer , infer_outputs )
818
820
input_data = infer_outputs [output_layer ].reshape (layer_shape ) if infer_outputs else np .zeros (
819
821
layer_shape , dtype = format_map [self .inputs [lstm_var ].element_type .get_type_name ()]
820
822
)
@@ -831,13 +833,15 @@ def print_input_output_info(network, prefix=None):
831
833
network_outputs = network .outputs
832
834
for input_info in network_inputs :
833
835
input_node = input_info .get_node ()
834
- print_info ('\t Layer name: {}' .format (input_node .friendly_name ))
836
+ print_info ('\t Node name: {}' .format (input_node .friendly_name ))
837
+ print_info ('\t Tensor names: {}' .format (', ' .join (input_info .get_names ())))
835
838
print_info ('\t precision: {}' .format (input_node .element_type .get_type_name ()))
836
839
print_info ('\t shape: {}\n ' .format (parse_partial_shape (input_node .get_partial_shape ())))
837
840
print_info ('Output info' )
838
841
for output_info in network_outputs :
839
842
out_node = output_info .get_node ()
840
- print_info ('\t Layer name: {}' .format (out_node .friendly_name ))
843
+ print_info ('\t Node name: {}' .format (out_node .friendly_name ))
844
+ print_info ('\t Tensor names: {}' .format (', ' .join (output_info .get_names ())))
841
845
precision = out_node .get_output_element_type (0 ).get_type_name ()
842
846
print_info ('\t precision: {}' .format (precision ))
843
847
shape = parse_partial_shape (out_node .get_output_partial_shape (0 ))
0 commit comments