Skip to content

Commit 35ca5c7

Browse files
committed
fix test case
1 parent 38ecb93 commit 35ca5c7

File tree

1 file changed

+53
-21
lines changed

1 file changed

+53
-21
lines changed

efficientdet/tf2/efficientdet_keras_test.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import efficientdet_arch as legacy_arch
2121
import hparams_config
2222
from tf2 import efficientdet_keras
23+
from tf2 import train_lib
2324

2425
SEED = 111111
2526

@@ -37,7 +38,6 @@ def test_model_output(self):
3738
inputs_shape = [1, 512, 512, 3]
3839
config = hparams_config.get_efficientdet_config('efficientdet-d0')
3940
config.heads = ['object_detection', 'segmentation']
40-
tmp_ckpt = os.path.join(tempfile.mkdtemp(), 'ckpt')
4141
with tf.Session(graph=tf.Graph()) as sess:
4242
feats = tf.ones(inputs_shape)
4343
tf.random.set_random_seed(SEED)
@@ -48,7 +48,6 @@ def test_model_output(self):
4848
grads = tf.nest.map_structure(lambda output: tf.gradients(output, feats),
4949
outputs)
5050
keras_class_grads, keras_box_grads, _ = sess.run(grads)
51-
model.save_weights(tmp_ckpt)
5251
with tf.Session(graph=tf.Graph()) as sess:
5352
feats = tf.ones(inputs_shape)
5453
tf.random.set_random_seed(SEED)
@@ -60,41 +59,57 @@ def test_model_output(self):
6059
legacy_class_grads, legacy_box_grads = sess.run(grads)
6160

6261
for i in range(3, 8):
63-
self.assertAllClose(
64-
keras_class_out[i - 3], legacy_class_out[i], rtol=1e-4, atol=1e-4)
65-
self.assertAllClose(
66-
keras_box_out[i - 3], legacy_box_out[i], rtol=1e-4, atol=1e-4)
67-
self.assertAllClose(
68-
keras_class_grads[i - 3], legacy_class_grads[i], rtol=1e-4, atol=1e-4)
69-
self.assertAllClose(
70-
keras_box_grads[i - 3], legacy_box_grads[i], rtol=1e-4, atol=1e-4)
62+
self.assertAllEqual(
63+
keras_class_out[i - 3], legacy_class_out[i])
64+
self.assertAllEqual(
65+
keras_box_out[i - 3], legacy_box_out[i])
66+
self.assertAllEqual(
67+
keras_class_grads[i - 3], legacy_class_grads[i])
68+
self.assertAllEqual(
69+
keras_box_grads[i - 3], legacy_box_grads[i])
7170

7271
def test_eager_output(self):
7372
inputs_shape = [1, 512, 512, 3]
7473
config = hparams_config.get_efficientdet_config('efficientdet-d0')
75-
config.heads = ['object_detection', 'segmentation']
74+
config.heads = ['object_detection']
7675
tmp_ckpt = os.path.join(tempfile.mkdtemp(), 'ckpt2')
7776

7877
with tf.Session(graph=tf.Graph()) as sess:
7978
feats = tf.ones(inputs_shape)
8079
tf.random.set_random_seed(SEED)
8180
model = efficientdet_keras.EfficientDetNet(config=config)
8281
outputs = model(feats, True)
82+
grads = tf.nest.map_structure(lambda output: tf.gradients(output, feats),
83+
outputs)
8384
sess.run(tf.global_variables_initializer())
84-
keras_class_out, keras_box_out, keras_seg_out = sess.run(outputs)
85+
keras_class_out, keras_box_out = sess.run(outputs)
86+
legacy_class_grads, legacy_box_grads = sess.run(grads)
8587
model.save_weights(tmp_ckpt)
8688

8789
feats = tf.ones(inputs_shape)
8890
model = efficientdet_keras.EfficientDetNet(config=config)
91+
model.build(inputs_shape)
8992
model.load_weights(tmp_ckpt)
90-
eager_class_out, eager_box_out, eager_seg_out = model(feats, True)
93+
94+
@tf.function
95+
def _run(feats):
96+
with tf.GradientTape(persistent=True) as tape:
97+
tape.watch(feats)
98+
eager_class_out, eager_box_out = model(feats, True)
99+
class_grads, box_grads = tf.nest.map_structure(
100+
lambda output: tape.gradient(output, feats),
101+
[eager_class_out, eager_box_out])
102+
return eager_class_out, eager_box_out, class_grads, box_grads
103+
eager_class_out, eager_box_out, class_grads, box_grads = _run(feats)
91104
for i in range(5):
92-
self.assertAllClose(
93-
eager_class_out[i], keras_class_out[i], rtol=1e-4, atol=1e-4)
94-
self.assertAllClose(
95-
eager_box_out[i], keras_box_out[i], rtol=1e-4, atol=1e-4)
96-
self.assertAllClose(
97-
eager_seg_out, keras_seg_out, rtol=1e-4, atol=1e-4)
105+
self.assertAllEqual(
106+
eager_class_out[i], keras_class_out[i])
107+
self.assertAllEqual(
108+
eager_box_out[i], keras_box_out[i])
109+
self.assertAllEqual(
110+
class_grads[i], legacy_class_grads[i][0])
111+
self.assertAllEqual(
112+
box_grads[i], legacy_box_grads[i][0])
98113

99114
def test_build_feature_network(self):
100115
config = hparams_config.get_efficientdet_config('efficientdet-d0')
@@ -130,8 +145,8 @@ def test_build_feature_network(self):
130145
legacy_grads = sess.run(grads[3:6])
131146

132147
for i in range(config.min_level, config.max_level + 1):
133-
self.assertAllClose(keras_feats[i - config.min_level], legacy_feats[i])
134-
self.assertAllClose(keras_grads[i - config.min_level],
148+
self.assertAllEqual(keras_feats[i - config.min_level], legacy_feats[i])
149+
self.assertAllEqual(keras_grads[i - config.min_level],
135150
legacy_grads[i - config.min_level])
136151

137152
def test_model_variables(self):
@@ -192,6 +207,23 @@ def test_resample_feature_map(self):
192207
actual_result = resample_layer(feat, training, all_feats)
193208
self.assertAllCloseAccordingToType(expect_result, actual_result)
194209

210+
def test_hub_model(self):
211+
image = tf.random.uniform((1, 320, 320, 3))
212+
keras_model = efficientdet_keras.EfficientDetNet('efficientdet-lite0')
213+
tmp_ckpt = os.path.join(tempfile.mkdtemp(), 'ckpt')
214+
keras_model.config.model_dir = tmp_ckpt
215+
base_model = train_lib.EfficientDetNetTrainHub(keras_model.config,
216+
"https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1")
217+
cls_outputs, box_outputs = tf.function(base_model)(image, training=False)
218+
keras_model.build(image.shape)
219+
d1 = {var.name: var for var in base_model.variables}
220+
for var in keras_model.variables:
221+
var.assign(d1[var.name].numpy())
222+
cls_outputs2, box_outputs2 = tf.function(keras_model)(image, False)
223+
for c1, b1, c2, b2 in zip(cls_outputs, box_outputs, cls_outputs2, box_outputs2):
224+
self.assertAllEqual(c1, c2)
225+
self.assertAllEqual(b1, b2)
226+
195227
def test_resample_var_names(self):
196228
with tf.Graph().as_default():
197229
feat = tf.random.uniform([1, 16, 16, 320])

0 commit comments

Comments
 (0)