2525from backbone import efficientnet_builder
2626from tf2 import fpn_configs
2727from tf2 import postprocess
28- from tf2 import tfmot
2928from tf2 import util_keras
3029
3130
@@ -56,7 +55,6 @@ def __init__(self,
5655 strategy ,
5756 weight_method ,
5857 data_format ,
59- model_optimizations ,
6058 name = 'fnode' ):
6159 super ().__init__ (name = name )
6260 self .feat_level = feat_level
@@ -73,7 +71,6 @@ def __init__(self,
7371 self .conv_bn_act_pattern = conv_bn_act_pattern
7472 self .resample_layers = []
7573 self .vars = []
76- self .model_optimizations = model_optimizations
7774
7875 def fuse_features (self , nodes ):
7976 """Fuse features from different resolutions and return a weighted sum.
@@ -141,7 +138,6 @@ def build(self, feats_shape):
141138 self .conv_after_downsample ,
142139 strategy = self .strategy ,
143140 data_format = self .data_format ,
144- model_optimizations = self .model_optimizations ,
145141 name = name ))
146142 if self .weight_method == 'attn' :
147143 self ._add_wsm ('ones' )
@@ -161,7 +157,6 @@ def build(self, feats_shape):
161157 self .act_type ,
162158 self .data_format ,
163159 self .strategy ,
164- self .model_optimizations ,
165160 name = 'op_after_combine{}' .format (len (feats_shape )))
166161 self .built = True
167162 super ().build (feats_shape )
@@ -188,7 +183,6 @@ def __init__(self,
188183 act_type ,
189184 data_format ,
190185 strategy ,
191- model_optimizations ,
192186 name = 'op_after_combine' ):
193187 super ().__init__ (name = name )
194188 self .conv_bn_act_pattern = conv_bn_act_pattern
@@ -211,10 +205,6 @@ def __init__(self,
211205 use_bias = not self .conv_bn_act_pattern ,
212206 data_format = self .data_format ,
213207 name = 'conv' )
214- if model_optimizations :
215- for method in model_optimizations .keys ():
216- self .conv_op = (
217- tfmot .get_method (method )(self .conv_op ))
218208 self .bn = util_keras .build_batch_norm (
219209 is_training_bn = self .is_training_bn ,
220210 data_format = self .data_format ,
@@ -244,7 +234,6 @@ def __init__(self,
244234 data_format = None ,
245235 pooling_type = None ,
246236 upsampling_type = None ,
247- model_optimizations = None ,
248237 name = 'resample_p0' ):
249238 super ().__init__ (name = name )
250239 self .apply_bn = apply_bn
@@ -262,9 +251,6 @@ def __init__(self,
262251 padding = 'same' ,
263252 data_format = self .data_format ,
264253 name = 'conv2d' )
265- if model_optimizations :
266- for method in model_optimizations .keys ():
267- self .conv2d = tfmot .get_method (method )(self .conv2d )
268254 self .bn = util_keras .build_batch_norm (
269255 is_training_bn = self .is_training_bn ,
270256 data_format = self .data_format ,
@@ -291,14 +277,14 @@ def _pool2d(self, inputs, height, width, target_height, target_width):
291277
292278 def _upsample2d (self , inputs , target_height , target_width ):
293279 if self .data_format == 'channels_first' :
294- inputs = tf .compat . v1 . transpose (inputs , perm = [0 , 2 , 3 , 1 ])
295- outputs = tf .cast (
280+ inputs = tf .transpose (inputs , [0 , 2 , 3 , 1 ])
281+ resized = tf .cast (
296282 tf .compat .v1 .image .resize_nearest_neighbor (
297283 tf .cast (inputs , tf .float32 ), [target_height , target_width ]),
298284 inputs .dtype )
299285 if self .data_format == 'channels_first' :
300- outputs = tf .compat . v1 . transpose (outputs , perm = [0 , 3 , 1 , 2 ])
301- return outputs
286+ resized = tf .transpose (resized , [0 , 3 , 1 , 2 ])
287+ return resized
302288
303289 def _maybe_apply_1x1 (self , feat , training , num_channels ):
304290 """Apply 1x1 conv to change layer width if necessary."""
@@ -428,15 +414,14 @@ def __init__(self,
428414 def _conv_bn_act (self , image , i , level_id , training ):
429415 conv_op = self .conv_ops [i ]
430416 bn = self .bns [i ][level_id ]
431- act_type = self .act_type
432417
433418 @utils .recompute_grad (self .grad_checkpoint )
434419 def _call (image ):
435420 original_image = image
436421 image = conv_op (image )
437422 image = bn (image , training = training )
438423 if self .act_type :
439- image = utils .activation_fn (image , act_type )
424+ image = utils .activation_fn (image , self . act_type )
440425 if i > 0 and self .survival_prob :
441426 image = utils .drop_connect (image , training , self .survival_prob )
442427 image = image + original_image
@@ -590,15 +575,14 @@ def __init__(self,
590575 def _conv_bn_act (self , image , i , level_id , training ):
591576 conv_op = self .conv_ops [i ]
592577 bn = self .bns [i ][level_id ]
593- act_type = self .act_type
594578
595579 @utils .recompute_grad (self .grad_checkpoint )
596580 def _call (image ):
597581 original_image = image
598582 image = conv_op (image )
599583 image = bn (image , training = training )
600584 if self .act_type :
601- image = utils .activation_fn (image , act_type )
585+ image = utils .activation_fn (image , self . act_type )
602586 if i > 0 and self .survival_prob :
603587 image = utils .drop_connect (image , training , self .survival_prob )
604588 image = image + original_image
@@ -754,6 +738,7 @@ class FPNCell(tf.keras.layers.Layer):
754738
755739 def __init__ (self , config , name = 'fpn_cell' ):
756740 super ().__init__ (name = name )
741+ logging .info ('building FPNCell %s' , name )
757742 self .config = config
758743 if config .fpn_config :
759744 self .fpn_config = config .fpn_config
@@ -778,7 +763,6 @@ def __init__(self, config, name='fpn_cell'):
778763 strategy = config .strategy ,
779764 weight_method = self .fpn_config .weight_method ,
780765 data_format = config .data_format ,
781- model_optimizations = config .model_optimizations ,
782766 name = 'fnode%d' % i )
783767 self .fnodes .append (fnode )
784768
@@ -839,7 +823,6 @@ def __init__(self,
839823 conv_after_downsample = config .conv_after_downsample ,
840824 strategy = config .strategy ,
841825 data_format = config .data_format ,
842- model_optimizations = config .model_optimizations ,
843826 name = 'resample_p%d' % level ,
844827 ))
845828 self .fpn_cells = FPNCells (config )
@@ -953,7 +936,7 @@ def map_fn(image):
953936 if raw_images .shape .as_list ()[0 ]: # fixed batch size.
954937 batch_size = raw_images .shape .as_list ()[0 ]
955938 outputs = [map_fn (raw_images [i ]) for i in range (batch_size )]
956- return [tf .stack (y ) for y in zip (* outputs )]
939+ return [tf .stop_gradient ( tf . stack (y ) ) for y in zip (* outputs )]
957940
958941 # otherwise treat it as dynamic batch size.
959942 return tf .vectorized_map (map_fn , raw_images )
@@ -999,6 +982,8 @@ def call(self, inputs, training=False, pre_mode='infer', post_mode='global'):
999982 config .mean_rgb , config .stddev_rgb ,
1000983 pre_mode )
1001984 # network.
985+ if config .data_format == 'channels_first' :
986+ inputs = tf .transpose (inputs , [0 , 3 , 1 , 2 ])
1002987 outputs = super ().call (inputs , training )
1003988
1004989 if 'object_detection' in config .heads and post_mode :
0 commit comments