@@ -151,13 +151,17 @@ def try_to_set_default_layout(self):
151
151
shape = parse_partial_shape (input_node .get_node ().partial_shape )
152
152
if len (shape ) != 4 :
153
153
continue
154
+ if input_node .get_node ().layout .has_name ('C' ):
155
+ channel_dim = input_node .get_node ().layout .get_index_by_name ('C' )
156
+ if channel_dim in [3 , - 1 ]:
157
+ self .default_layout = 'NHWC'
158
+ return
154
159
if shape [- 1 ] in [1 , 3 , 4 ]:
155
160
self .default_layout = 'NHWC'
156
161
return
157
162
self .default_layout = 'NCHW'
158
163
return
159
164
160
-
161
165
@property
162
166
def device (self ):
163
167
return self ._device
@@ -500,16 +504,15 @@ def _create_network(self, input_shapes=None):
500
504
if compiled_model :
501
505
self .network = None
502
506
self .exec_network = self .ie_core .import_model (str (self ._model ), self ._device )
503
- self .original_outputs = list (self .exec_network .outputs .keys ())
504
- has_info = hasattr (self .exec_network , 'input_info' )
505
- if has_info :
506
- ie_input_info = {name : data .input_data for name , data in self .exec_network .input_info .items ()}
507
- else :
508
- ie_input_info = self .exec_network .inputs
509
- first_input = next (iter (ie_input_info ))
510
- input_info = ie_input_info [first_input ]
511
- batch_pos = input_info .layout .find ('N' )
512
- self ._batch = input_info .shape [batch_pos ] if batch_pos != - 1 else 1
507
+ self .original_outputs = self .exec_network .outputs
508
+ ie_input_info = self .exec_network .inputs
509
+ input_info = ie_input_info [0 ]
510
+ batch_pos = (
511
+ input_info .get_node ().layout .get_index_by_name ('N' )
512
+ if input_info .get_node ().layout .has_name ('N' ) else - 1
513
+ )
514
+
515
+ self ._batch = parse_partial_shape (input_info .partial_shape )[batch_pos ] if batch_pos != - 1 else 1
513
516
return
514
517
if self ._weights is None and self ._model .suffix != '.onnx' :
515
518
self ._weights = model_path .parent / (model_path .name .split (model_path .suffix )[0 ] + '.bin' )
@@ -611,9 +614,17 @@ def dyn_batch_only(self):
611
614
return True
612
615
for input_name in self .dyn_input_layers :
613
616
partial_shape = self ._partial_shapes [input_name ]
617
+ num_undef = 0
618
+ for i in partial_shape :
619
+ if i == - 1 :
620
+ num_undef += 1
621
+ if num_undef > 1 :
622
+ return False
614
623
layout = self .inputs [input_name ].layout
615
- if str ( layout ) == '[ ...]' :
624
+ if ' ...' in str ( layout ) :
616
625
layout = self .get_layout_from_config (input_name )
626
+ else :
627
+ layout = str (layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
617
628
if not layout :
618
629
return False
619
630
for dim , layout_dim in zip (partial_shape , layout ):
@@ -628,6 +639,17 @@ def get_layout_from_config(self, input_name):
628
639
return input_config .get ('layout' , '' )
629
640
return ''
630
641
642
+ @property
643
+ def layout_mapping (self ):
644
+ def prepare_layout_string (layout ):
645
+ layout = str (layout )
646
+ return layout .replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
647
+ inputs = self .network .inputs if self .network is not None else self .exec_network .inputs
648
+ layouts = {}
649
+ for input_node in inputs :
650
+ layouts [input_node .get_node ().friendly_name ] = prepare_layout_string (input_node .get_node ().layout )
651
+ return layouts
652
+
631
653
def load_ir (self , xml_path , bin_path , log = False ):
632
654
self ._model = xml_path
633
655
self ._weights = bin_path
0 commit comments