2020import efficientdet_arch as legacy_arch
2121import hparams_config
2222from tf2 import efficientdet_keras
23+ from tf2 import train_lib
2324
2425SEED = 111111
2526
@@ -37,7 +38,6 @@ def test_model_output(self):
3738 inputs_shape = [1 , 512 , 512 , 3 ]
3839 config = hparams_config .get_efficientdet_config ('efficientdet-d0' )
3940 config .heads = ['object_detection' , 'segmentation' ]
40- tmp_ckpt = os .path .join (tempfile .mkdtemp (), 'ckpt' )
4141 with tf .Session (graph = tf .Graph ()) as sess :
4242 feats = tf .ones (inputs_shape )
4343 tf .random .set_random_seed (SEED )
@@ -48,7 +48,6 @@ def test_model_output(self):
4848 grads = tf .nest .map_structure (lambda output : tf .gradients (output , feats ),
4949 outputs )
5050 keras_class_grads , keras_box_grads , _ = sess .run (grads )
51- model .save_weights (tmp_ckpt )
5251 with tf .Session (graph = tf .Graph ()) as sess :
5352 feats = tf .ones (inputs_shape )
5453 tf .random .set_random_seed (SEED )
@@ -60,41 +59,57 @@ def test_model_output(self):
6059 legacy_class_grads , legacy_box_grads = sess .run (grads )
6160
6261 for i in range (3 , 8 ):
63- self .assertAllClose (
64- keras_class_out [i - 3 ], legacy_class_out [i ], rtol = 1e-4 , atol = 1e-4 )
65- self .assertAllClose (
66- keras_box_out [i - 3 ], legacy_box_out [i ], rtol = 1e-4 , atol = 1e-4 )
67- self .assertAllClose (
68- keras_class_grads [i - 3 ], legacy_class_grads [i ], rtol = 1e-4 , atol = 1e-4 )
69- self .assertAllClose (
70- keras_box_grads [i - 3 ], legacy_box_grads [i ], rtol = 1e-4 , atol = 1e-4 )
62+ self .assertAllEqual (
63+ keras_class_out [i - 3 ], legacy_class_out [i ])
64+ self .assertAllEqual (
65+ keras_box_out [i - 3 ], legacy_box_out [i ])
66+ self .assertAllEqual (
67+ keras_class_grads [i - 3 ], legacy_class_grads [i ])
68+ self .assertAllEqual (
69+ keras_box_grads [i - 3 ], legacy_box_grads [i ])
7170
7271 def test_eager_output (self ):
7372 inputs_shape = [1 , 512 , 512 , 3 ]
7473 config = hparams_config .get_efficientdet_config ('efficientdet-d0' )
75- config .heads = ['object_detection' , 'segmentation' ]
74+ config .heads = ['object_detection' ]
7675 tmp_ckpt = os .path .join (tempfile .mkdtemp (), 'ckpt2' )
7776
7877 with tf .Session (graph = tf .Graph ()) as sess :
7978 feats = tf .ones (inputs_shape )
8079 tf .random .set_random_seed (SEED )
8180 model = efficientdet_keras .EfficientDetNet (config = config )
8281 outputs = model (feats , True )
82+ grads = tf .nest .map_structure (lambda output : tf .gradients (output , feats ),
83+ outputs )
8384 sess .run (tf .global_variables_initializer ())
84- keras_class_out , keras_box_out , keras_seg_out = sess .run (outputs )
85+ keras_class_out , keras_box_out = sess .run (outputs )
86+ legacy_class_grads , legacy_box_grads = sess .run (grads )
8587 model .save_weights (tmp_ckpt )
8688
8789 feats = tf .ones (inputs_shape )
8890 model = efficientdet_keras .EfficientDetNet (config = config )
91+ model .build (inputs_shape )
8992 model .load_weights (tmp_ckpt )
90- eager_class_out , eager_box_out , eager_seg_out = model (feats , True )
93+
94+ @tf .function
95+ def _run (feats ):
96+ with tf .GradientTape (persistent = True ) as tape :
97+ tape .watch (feats )
98+ eager_class_out , eager_box_out = model (feats , True )
99+ class_grads , box_grads = tf .nest .map_structure (
100+ lambda output : tape .gradient (output , feats ),
101+ [eager_class_out , eager_box_out ])
102+ return eager_class_out , eager_box_out , class_grads , box_grads
103+ eager_class_out , eager_box_out , class_grads , box_grads = _run (feats )
91104 for i in range (5 ):
92- self .assertAllClose (
93- eager_class_out [i ], keras_class_out [i ], rtol = 1e-4 , atol = 1e-4 )
94- self .assertAllClose (
95- eager_box_out [i ], keras_box_out [i ], rtol = 1e-4 , atol = 1e-4 )
96- self .assertAllClose (
97- eager_seg_out , keras_seg_out , rtol = 1e-4 , atol = 1e-4 )
105+ self .assertAllEqual (
106+ eager_class_out [i ], keras_class_out [i ])
107+ self .assertAllEqual (
108+ eager_box_out [i ], keras_box_out [i ])
109+ self .assertAllEqual (
110+ class_grads [i ], legacy_class_grads [i ][0 ])
111+ self .assertAllEqual (
112+ box_grads [i ], legacy_box_grads [i ][0 ])
98113
99114 def test_build_feature_network (self ):
100115 config = hparams_config .get_efficientdet_config ('efficientdet-d0' )
@@ -130,8 +145,8 @@ def test_build_feature_network(self):
130145 legacy_grads = sess .run (grads [3 :6 ])
131146
132147 for i in range (config .min_level , config .max_level + 1 ):
133- self .assertAllClose (keras_feats [i - config .min_level ], legacy_feats [i ])
134- self .assertAllClose (keras_grads [i - config .min_level ],
148+ self .assertAllEqual (keras_feats [i - config .min_level ], legacy_feats [i ])
149+ self .assertAllEqual (keras_grads [i - config .min_level ],
135150 legacy_grads [i - config .min_level ])
136151
137152 def test_model_variables (self ):
@@ -192,6 +207,23 @@ def test_resample_feature_map(self):
192207 actual_result = resample_layer (feat , training , all_feats )
193208 self .assertAllCloseAccordingToType (expect_result , actual_result )
194209
210+ def test_hub_model (self ):
211+ image = tf .random .uniform ((1 , 320 , 320 , 3 ))
212+ keras_model = efficientdet_keras .EfficientDetNet ('efficientdet-lite0' )
213+ tmp_ckpt = os .path .join (tempfile .mkdtemp (), 'ckpt' )
214+ keras_model .config .model_dir = tmp_ckpt
215+ base_model = train_lib .EfficientDetNetTrainHub (keras_model .config ,
216+ "https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1" )
217+ cls_outputs , box_outputs = tf .function (base_model )(image , training = False )
218+ keras_model .build (image .shape )
219+ d1 = {var .name : var for var in base_model .variables }
220+ for var in keras_model .variables :
221+ var .assign (d1 [var .name ].numpy ())
222+ cls_outputs2 , box_outputs2 = tf .function (keras_model )(image , False )
223+ for c1 , b1 , c2 , b2 in zip (cls_outputs , box_outputs , cls_outputs2 , box_outputs2 ):
224+ self .assertAllEqual (c1 , c2 )
225+ self .assertAllEqual (b1 , b2 )
226+
195227 def test_resample_var_names (self ):
196228 with tf .Graph ().as_default ():
197229 feat = tf .random .uniform ([1 , 16 , 16 , 320 ])
0 commit comments