@@ -2400,6 +2400,14 @@ def __init__(self, name, inputs, axis, offset, shape, **xargs):
2400
2400
image_conf .img_size_y = input_layer .height
2401
2401
image_conf .channels = input_layer .size / (input_layer .width *
2402
2402
input_layer .height )
2403
+ # only support for 4-dims inputs and NCHW order
2404
+ if (len (self .config .inputs ) == 2 ):
2405
+ self .set_layer_height_width (
2406
+ self .get_input_layer (1 ).height , self .get_input_layer (1 ).width )
2407
+ self .set_layer_size (self .get_input_layer (1 ).size )
2408
+ else :
2409
+ self .set_layer_height_width (shape [- 2 ], shape [- 1 ])
2410
+ self .set_layer_size (reduce (lambda x , y : x * y , shape [1 :]))
2403
2411
2404
2412
2405
2413
@config_layer ('batch_norm' )
@@ -3849,6 +3857,26 @@ def __init__(self, name, inputs, reshape, **xargs):
3849
3857
name , 'switch_order' , 0 , inputs = inputs , ** xargs )
3850
3858
self .config .reshape_conf .height_axis .extend (reshape ['height' ])
3851
3859
self .config .reshape_conf .width_axis .extend (reshape ['width' ])
3860
+ input_layer = self .get_input_layer (0 )
3861
+ if reshape is None :
3862
+ self .set_layer_size (input_layer .size )
3863
+ else :
3864
+ in_h = input_layer .height
3865
+ in_w = input_layer .width
3866
+ out_dims = None
3867
+ if input_layer .has_depth ():
3868
+ in_d = input_layer .depth
3869
+ in_c = input_layer .size / in_h / in_w / in_d
3870
+ # batch_size, depth, height, width, channel
3871
+ out_dims = [0 , in_d , in_h , in_w , in_c ]
3872
+ else :
3873
+ in_c = input_layer .size / in_h / in_w
3874
+ # batch_size, height, width, channel
3875
+ out_dims = [0 , in_h , in_w , in_c ]
3876
+ # Because (reshape['width'][0] > 0) always be true.
3877
+ # So out_dims[0] won't be used.
3878
+ size = reduce (lambda x , y : x * y , out_dims [reshape ['width' ][0 ]:])
3879
+ self .set_layer_size (size )
3852
3880
3853
3881
3854
3882
@config_layer ('scale_sub_region' )
0 commit comments