1717import tempfile
1818from absl import logging
1919import tensorflow .compat .v1 as tf
20+ import tensorflow .compat .v2 as tf2
2021import efficientdet_arch as legacy_arch
2122import hparams_config
2223from tf2 import efficientdet_keras
@@ -38,26 +39,37 @@ def test_model_output(self):
3839 inputs_shape = [1 , 512 , 512 , 3 ]
3940 config = hparams_config .get_efficientdet_config ('efficientdet-d0' )
4041 config .heads = ['object_detection' , 'segmentation' ]
42+ tf2 .keras .utils .set_random_seed (SEED )
4143 with tf .Session (graph = tf .Graph ()) as sess :
4244 feats = tf .ones (inputs_shape )
43- tf .random .set_random_seed (SEED )
4445 model = efficientdet_keras .EfficientDetNet (config = config )
4546 outputs = model (feats , True )
4647 sess .run (tf .global_variables_initializer ())
4748 keras_class_out , keras_box_out , _ = sess .run (outputs )
4849 grads = tf .nest .map_structure (lambda output : tf .gradients (output , feats ),
4950 outputs )
5051 keras_class_grads , keras_box_grads , _ = sess .run (grads )
52+ vars = list (filter (
53+ lambda var : not var .name .startswith ('segmentation' ),
54+ tf .global_variables ()))
55+ vars .sort (key = lambda var : var .name )
56+ keras_vars_names = [var .name for var in vars ]
57+ keras_vars_values = sess .run (vars )
58+
5159 with tf .Session (graph = tf .Graph ()) as sess :
5260 feats = tf .ones (inputs_shape )
53- tf .random .set_random_seed (SEED )
5461 outputs = legacy_arch .efficientdet (feats , config = config )
55- sess .run (tf .global_variables_initializer ())
62+ vars = tf .global_variables ()
63+ vars .sort (key = lambda var : var .name )
64+ legacy_vars_names = [var .name for var in vars ]
65+ sess .run ([var .assign (val ) for val , var in zip (keras_vars_values , vars )])
5666 legacy_class_out , legacy_box_out = sess .run (outputs )
5767 grads = tf .nest .map_structure (lambda output : tf .gradients (output , feats ),
5868 outputs )
5969 legacy_class_grads , legacy_box_grads = sess .run (grads )
6070
71+ self .assertAllEqual (keras_vars_names , legacy_vars_names )
72+
6173 for i in range (3 , 8 ):
6274 self .assertAllEqual (
6375 keras_class_out [i - 3 ], legacy_class_out [i ])
@@ -76,7 +88,7 @@ def test_eager_output(self):
7688
7789 with tf .Session (graph = tf .Graph ()) as sess :
7890 feats = tf .ones (inputs_shape )
79- tf . random .set_random_seed (SEED )
91+ tf2 . keras . utils .set_random_seed (SEED )
8092 model = efficientdet_keras .EfficientDetNet (config = config )
8193 outputs = model (feats , True )
8294 grads = tf .nest .map_structure (lambda output : tf .gradients (output , feats ),
@@ -120,7 +132,7 @@ def test_build_feature_network(self):
120132 tf .ones ([1 , 32 , 32 , 112 ]), # level 4
121133 tf .ones ([1 , 16 , 16 , 320 ]), # level 5
122134 ]
123- tf . random .set_random_seed (SEED )
135+ tf2 . keras . utils .set_random_seed (SEED )
124136 fpn_cell = efficientdet_keras .FPNCells (config )
125137 new_feats1 = fpn_cell (inputs , True )
126138 sess .run (tf .global_variables_initializer ())
@@ -137,7 +149,7 @@ def test_build_feature_network(self):
137149 4 : tf .ones ([1 , 32 , 32 , 112 ]),
138150 5 : tf .ones ([1 , 16 , 16 , 320 ])
139151 }
140- tf . random .set_random_seed (SEED )
152+ tf2 . keras . utils .set_random_seed (SEED )
141153 new_feats2 = legacy_arch .build_feature_network (inputs , config )
142154 sess .run (tf .global_variables_initializer ())
143155 legacy_feats = sess .run (new_feats2 )
@@ -185,7 +197,7 @@ def test_resample_feature_map(self):
185197 for strategy in ['tpu' , '' ]:
186198 with self .subTest (
187199 apply_bn = apply_bn , training = training , strategy = strategy ):
188- tf . random .set_random_seed (SEED )
200+ tf2 . keras . utils .set_random_seed (SEED )
189201 expect_result = legacy_arch .resample_feature_map (
190202 feat ,
191203 name = 'resample_p0' ,
@@ -195,7 +207,7 @@ def test_resample_feature_map(self):
195207 apply_bn = apply_bn ,
196208 is_training = training ,
197209 strategy = strategy )
198- tf . random .set_random_seed (SEED )
210+ tf2 . keras . utils .set_random_seed (SEED )
199211 resample_layer = efficientdet_keras .ResampleFeatureMap (
200212 name = 'resample_p0' ,
201213 feat_level = 0 ,
0 commit comments