@@ -479,17 +479,22 @@ def __init__(self, config, name=None):
479
479
480
480
def set_input_shape (self , input_shape ):
481
481
def is_image_input (shape ):
482
- return len (shape ) == 4 and shape [1 ] in [1 , 3 , 4 ]
482
+ return len (shape ) == 4 and (shape [1 ] in [1 , 3 , 4 ] or shape [- 1 ] in [1 , 3 , 4 ])
483
+
484
+ def is_nhwc (shape ):
485
+ return shape [- 1 ] in [1 , 3 , 4 ]
483
486
if input_shape is None :
484
487
raise ConfigError ('resize to input size impossible' )
485
488
image_inputs = [value for value in input_shape .values () if is_image_input (value )]
486
489
if not image_inputs :
487
490
raise ConfigError ('image input is not detected' )
488
491
if len (image_inputs ) == 1 :
489
- self .dst_height , self .dst_width = image_inputs [0 ][2 :]
492
+ self .dst_height , self .dst_width = (
493
+ image_inputs [0 ][2 :] if not is_nhwc (image_inputs [0 ]) else image_inputs [0 ][1 :3 ]
494
+ )
490
495
else :
491
- self .dst_height = [im_input [2 ] for im_input in image_inputs ]
492
- self .dst_width = [im_input [3 ] for im_input in image_inputs ]
496
+ self .dst_height = [im_input [2 ] if not is_nhwc ( im_input ) else im_input [ 1 ] for im_input in image_inputs ]
497
+ self .dst_width = [im_input [3 ] if not is_nhwc ( im_input ) else im_input [ 2 ] for im_input in image_inputs ]
493
498
494
499
def process (self , image , annotation_meta = None ):
495
500
is_simple_case = not isinstance (image .data , list ) # otherwise -- pyramid, tiling, etc
0 commit comments