@@ -2400,15 +2400,14 @@ def __init__(self, name, inputs, axis, offset, shape, **xargs):
         image_conf.img_size_y = input_layer.height
         image_conf.channels = input_layer.size / (input_layer.width *
                                                   input_layer.height)
-
+        # only 4-D inputs in NCHW order are supported
         if (len(self.config.inputs) == 2):
             self.set_layer_height_width(
                 self.get_input_layer(1).height, self.get_input_layer(1).width)
             self.set_layer_size(self.get_input_layer(1).size)
         else:
-            # NCHW order
             self.set_layer_height_width(shape[-2], shape[-1])
-            self.set_layer_size(reduce(lambda x, y: x * y, shape))
+            self.set_layer_size(reduce(lambda x, y: x * y, shape[1:]))


@config_layer('batch_norm')
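For context on this hunk: `shape` is the NCHW output shape of the crop layer, and the layer size tracked here is per-sample, so the batch dimension must be excluded from the product. A minimal sketch of the before/after behavior (the `shape` value is hypothetical; `reduce` is a builtin in the Python 2 source, imported from `functools` here):

```python
from functools import reduce  # builtin in the Python 2 source

# Hypothetical NCHW shape: [batch, channels, height, width]
shape = [16, 3, 32, 32]

# Before the fix: the batch dimension leaks into the layer size.
old_size = reduce(lambda x, y: x * y, shape)       # 16 * 3 * 32 * 32 = 49152
# After the fix: per-sample size over shape[1:] only.
new_size = reduce(lambda x, y: x * y, shape[1:])   # 3 * 32 * 32 = 3072

assert new_size == 3 * 32 * 32
```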
@@ -3865,18 +3864,19 @@ def __init__(self, name, inputs, reshape, **xargs):
         else:
             in_h = input_layer.height
             in_w = input_layer.width
+            out_dims = None
             if input_layer.has_depth():
                 in_d = input_layer.depth
                 in_c = input_layer.size / in_h / in_w / in_d
+                # batch_size, depth, height, width, channel
                 out_dims = [0, in_d, in_h, in_w, in_c]
-                size = reduce(lambda x, y: x * y,
-                              out_dims[reshape['width'][0]:])
             else:
                 in_c = input_layer.size / in_h / in_w
+                # batch_size, height, width, channel
                 out_dims = [0, in_h, in_w, in_c]
-                size = reduce(lambda x, y: x * y,
-                              out_dims[reshape['width'][0]:])
-
+            # Because reshape['width'][0] is always greater than 0,
+            # out_dims[0] (the batch_size placeholder) is never used.
+            size = reduce(lambda x, y: x * y, out_dims[reshape['width'][0]:])
         self.set_layer_size(size)

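To make the hoisted size computation concrete, here is a small sketch of the slice logic (the function name and sample `reshape` dicts are illustrative, not from the patch): since `reshape['width'][0]` is at least 1, the slice `out_dims[reshape['width'][0]:]` always skips the `batch_size` placeholder at index 0, which is why a single `size` line can serve both the 4-D and 5-D branches.

```python
from functools import reduce  # builtin in the Python 2 source

def switch_order_size(out_dims, reshape):
    # out_dims[0] is a batch_size placeholder (0); because
    # reshape['width'][0] > 0, the slice below never touches it.
    return reduce(lambda x, y: x * y, out_dims[reshape['width'][0]:])

# 4-D case: out_dims = [batch_size, height, width, channel]
print(switch_order_size([0, 32, 32, 3], {'height': [1, 2], 'width': [3]}))        # 3
# 5-D case: out_dims = [batch_size, depth, height, width, channel]
print(switch_order_size([0, 4, 32, 32, 3], {'height': [1, 2], 'width': [3, 4]}))  # 96
```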