@@ -2401,6 +2401,15 @@ def __init__(self, name, inputs, axis, offset, shape, **xargs):
         image_conf.channels = input_layer.size / (input_layer.width *
                                                   input_layer.height)

+        if (len(self.config.inputs) == 2):
+            self.set_layer_height_width(
+                self.get_input_layer(1).height, self.get_input_layer(1).width)
+            self.set_layer_size(self.get_input_layer(1).size)
+        else:
+            # NCHW order
+            self.set_layer_height_width(shape[-2], shape[-1])
+            self.set_layer_size(reduce(lambda x, y: x * y, shape))
+

 @config_layer('batch_norm')
 class BatchNormLayer(LayerBase):
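In the crop-layer change above, the output geometry comes from the second (reference) input when two inputs are given, and otherwise from the explicit `shape` argument, interpreted in NCHW order: `shape[-2]` is the height, `shape[-1]` the width, and the layer size is the product of all entries. Below is a minimal standalone sketch of that computation; the function `crop_output_dims` and its arguments are illustrative only, not part of the PaddlePaddle API.

```python
from functools import reduce  # built in on Python 2; imported here for Python 3


def crop_output_dims(shape=None, ref_layer=None):
    """Mirror the branch above: with a reference input, copy its
    height/width/size; otherwise derive them from an NCHW shape."""
    if ref_layer is not None:  # two-input case: copy the reference layer
        return ref_layer['height'], ref_layer['width'], ref_layer['size']
    # single-input case: NCHW order, so height and width are the last two
    # axes, and the layer size is the product of all entries in the shape
    return shape[-2], shape[-1], reduce(lambda x, y: x * y, shape)


# Cropping to shape [1, 3, 24, 32]: height 24, width 32, size 1*3*24*32 = 2304
print(crop_output_dims(shape=[1, 3, 24, 32]))
# Cropping against a hypothetical 16x16x3 reference input instead:
print(crop_output_dims(ref_layer={'height': 16, 'width': 16, 'size': 768}))
```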
@@ -3850,6 +3859,16 @@ def __init__(self, name, inputs, reshape, **xargs):
             name, 'switch_order', 0, inputs=inputs, **xargs)
         self.config.reshape_conf.height_axis.extend(reshape['height'])
         self.config.reshape_conf.width_axis.extend(reshape['width'])
+        input_layer = self.get_input_layer(0)
+        if reshape is None:
+            self.set_layer_size(input_layer.size)
+        else:
+            inH = input_layer.height
+            inW = input_layer.width
+            inC = input_layer.size / inH / inW
+            out_dims = [0, inH, inW, inC]
+            size = reduce(lambda x, y: x * y, out_dims[reshape['width'][0]:])
+            self.set_layer_size(size)


 @config_layer('scale_sub_region')
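The switch_order addition recovers the channel count from the flat size, lays the dimensions out as [N, H, W, C], and takes the output size as the product of every dimension from the first width axis onward. (Note that `reshape['height']` is dereferenced two lines before the `reshape is None` check, so the None branch is unreachable as written.) A standalone sketch follows, with the illustrative function `switch_order_size` standing in for the layer logic; it is not part of the PaddlePaddle API.

```python
from functools import reduce  # built in on Python 2; imported here for Python 3


def switch_order_size(in_size, in_h, in_w, reshape=None):
    """Mirror the logic above for a layer that permutes NCHW into NHWC."""
    if reshape is None:
        return in_size  # no reshape given: the size is unchanged
    in_c = in_size // (in_h * in_w)   # channels recovered from the flat size
    out_dims = [0, in_h, in_w, in_c]  # NHWC order; index 0 is the batch axis
    # size is the product of every dim from the first width axis onward
    return reduce(lambda x, y: x * y, out_dims[reshape['width'][0]:])


# A 3x24x32 input (flat size 2304) with reshape {'height': [1], 'width': [2]}
# keeps the dims from the width axis on: 32 * 3 = 96
print(switch_order_size(2304, 24, 32, {'height': [1], 'width': [2]}))
```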