@@ -547,27 +547,29 @@ def _set_batch_size(self, batch_size):
547
547
if layer_name in self .const_inputs :
548
548
input_shapes [layer_name ] = parse_partial_shape (input_node .get_node ().partial_shape )
549
549
else :
550
- layer_shape = parse_partial_shape (input_node .get_node ().partial_shape )
551
- layout = self .inputs [layer_name ].layout
552
- if '...' in str (layout ):
553
- layout = self .get_layout_from_config (layer_name )
554
- else :
555
- layout = str (layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
550
+ layer_shape = list (parse_partial_shape (input_node .get_node ().partial_shape ))
551
+ layout = self ._process_layout (self .inputs [layer_name ].layout , layer_name )
556
552
batch_pos = layout .find ('N' )
557
553
if batch_pos != - 1 :
558
554
layer_shape [batch_pos ] = batch_size
559
555
input_shapes [layer_name ] = layer_shape
560
556
self ._reshape_input (input_shapes , batch_size == - 1 )
561
557
self ._batch = batch_size
562
558
559
+ def _process_layout (self , ov_layout , layer_name ):
560
+ if '...' in str (ov_layout ) or ov_layout is None :
561
+ ov_layout = self .get_layout_from_config (layer_name )
562
+ else :
563
+ ov_layout = str (ov_layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
564
+ return ov_layout
565
+
563
566
def _get_model_batch_size (self ):
564
567
input_nodes = self .network .inputs if self .network else self .exec_network .inputs
565
568
input_info = input_nodes [0 ]
566
- layout = input_info .get_node ().layout
567
- if '...' in str (layout ) or layout is None :
568
- layout = self .get_layout_from_config (input_info .get_node ().friendly_name )
569
- else :
570
- layout = str (layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
569
+ layout = (
570
+ self ._process_layout (input_info .get_node ().layout , input_info .get_node ().friendly_name )
571
+ or self .default_layout
572
+ )
571
573
batch_pos = layout .find ('N' )
572
574
if batch_pos != - 1 :
573
575
return parse_partial_shape (input_info .partial_shape )[batch_pos ]
@@ -702,6 +704,8 @@ def read_network(self, model, weights):
702
704
network = self .ie_core .read_model (model = str (model ), weights = str (weights ))
703
705
else :
704
706
network = self .ie_core .read_model (model = str (model ))
707
+ self .input_to_tensor_name = self .get_input_tensor_name_mapping (network )
708
+ self .input_to_index = {inp .get_node ().friendly_name : idx for idx , inp in enumerate (network .inputs )}
705
709
return network
706
710
707
711
def inputs_info_for_meta (self , inputs = None ):
0 commit comments