@@ -125,9 +125,10 @@ def __init__(self, config_entry, model_name='', delayed_model_loading=False,
125
125
self .get_value_from_config ('weights' ),
126
126
self .get_value_from_config ('_model_type' )
127
127
)
128
- self .load_network (log = True , preprocessing = preprocessor )
128
+ self .load_network (log = postpone_inputs_configuration , preprocessing = preprocessor )
129
129
self .allow_reshape_input = self .get_value_from_config ('allow_reshape_input' ) and self .network is not None
130
- self .try_to_set_default_layout ()
130
+ if postpone_inputs_configuration :
131
+ self .try_to_set_default_layout ()
131
132
else :
132
133
self .allow_reshape_input = self .get_value_from_config ('allow_reshape_input' )
133
134
self ._target_layout_mapping = {}
@@ -511,19 +512,15 @@ def _create_network(self, input_shapes=None):
511
512
self .network = None
512
513
self .exec_network = self .ie_core .import_model (str (self ._model ), self ._device )
513
514
self .original_outputs = self .exec_network .outputs
514
- ie_input_info = self .exec_network .inputs
515
- input_info = ie_input_info [0 ]
516
- batch_pos = (
517
- input_info .get_node ().layout .get_index_by_name ('N' )
518
- if input_info .get_node ().layout .has_name ('N' ) else - 1
519
- )
520
-
521
- self ._batch = parse_partial_shape (input_info .partial_shape )[batch_pos ] if batch_pos != - 1 else 1
515
+ model_batch = self ._get_model_batch_size ()
516
+ self ._batch = model_batch if model_batch is not None else 1
522
517
return
523
518
if self ._weights is None and self ._model .suffix != '.onnx' :
524
519
self ._weights = model_path .parent / (model_path .name .split (model_path .suffix )[0 ] + '.bin' )
525
520
self .network = self .read_network (self ._model , self ._weights )
526
521
self .original_outputs = self .network .outputs
522
+ model_batch = self ._get_model_batch_size ()
523
+ model_batch = 1 if model_batch is None else model_batch
527
524
outputs = self .config .get ('outputs' )
528
525
if outputs :
529
526
def output_preprocessing (output_string ):
@@ -536,13 +533,53 @@ def output_preprocessing(output_string):
536
533
self .network .add_outputs (preprocessed_outputs )
537
534
if input_shapes is not None :
538
535
self .network .reshape (input_shapes )
539
- self ._batch = self .config .get ('batch' , 1 )
536
+ self ._batch = self .config .get ('batch' , model_batch )
537
+ self ._set_batch_size (self ._batch )
540
538
affinity_map_path = self .config .get ('affinity_map' )
541
539
if affinity_map_path and self ._is_hetero ():
542
540
self ._set_affinity (affinity_map_path )
543
541
elif affinity_map_path :
544
542
warning ('affinity_map config is applicable only for HETERO device' )
545
543
544
+ def _set_batch_size (self , batch_size ):
545
+ model_batch_size = self ._get_model_batch_size ()
546
+ model_batch_size = 1 if model_batch_size is None else model_batch_size
547
+ if batch_size is None :
548
+ batch_size = model_batch_size
549
+ if batch_size == model_batch_size :
550
+ self ._batch = batch_size
551
+ return
552
+ input_shapes = {}
553
+ for input_node in self .network .inputs :
554
+ layer_name = input_node .get_node ().friendly_name
555
+ if layer_name in self .const_inputs :
556
+ input_shapes [layer_name ] = parse_partial_shape (layer_name .partial_shape )
557
+ else :
558
+ layer_shape = parse_partial_shape (layer_name .partial_shape )
559
+ layout = self .inputs [layer_name ].layout
560
+ if '...' in str (layout ):
561
+ layout = self .get_layout_from_config (layer_name )
562
+ else :
563
+ layout = str (layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
564
+ batch_pos = layout .find ('N' )
565
+ if batch_pos != - 1 :
566
+ layer_shape [batch_pos ] = batch_size
567
+ input_shapes [layer_name ] = layer_shape
568
+ self ._reshape_input (input_shapes , batch_size == - 1 )
569
+ self ._batch = batch_size
570
+
571
+ def _get_model_batch_size (self ):
572
+ input_nodes = self .network .inputs if self .network else self .exec_network .inputs
573
+ input_info = input_nodes [0 ]
574
+ if '...' in str (input_info .get_node ().layout ):
575
+ layout = self .get_layout_from_config (input_info .get_node ().friendly_name )
576
+ else :
577
+ layout = str (input_info .get_node ().layout ).replace ('[' , '' ).replace (']' , '' ).replace (',' , '' )
578
+ batch_pos = layout .find ('N' )
579
+ if batch_pos != - 1 :
580
+ return parse_partial_shape (input_info .partial_shape )[batch_pos ]
581
+ return None
582
+
546
583
def load_network (self , network = None , log = False , preprocessing = None ):
547
584
if hasattr (self , 'exec_network' ):
548
585
del self .exec_network
@@ -573,6 +610,7 @@ def update_input_configuration(self, input_config):
573
610
self ._set_precision ()
574
611
self ._set_input_shape ()
575
612
self .try_to_set_default_layout ()
613
+ self ._set_batch_size (self .config .get ('batch' ))
576
614
self .dyn_input_layers , self ._partial_shapes = self .get_dynamic_inputs (self .network )
577
615
self .print_input_output_info (self .network if self .network is not None else self .exec_network )
578
616
if self .preprocessor :
0 commit comments