Skip to content

Commit 861975e

Browse files
committed
fix broken test cases
1 parent 4be6845 commit 861975e

File tree

6 files changed

+37
-21
lines changed

6 files changed

+37
-21
lines changed

efficientdet/tf2/efficientdet_keras.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,8 @@ def conv2d_layer(cls, separable_conv, data_format):
451451
tf.keras.layers.SeparableConv2D,
452452
depth_multiplier=1,
453453
data_format=data_format,
454-
pointwise_initializer=tf.initializers.variance_scaling(),
455-
depthwise_initializer=tf.initializers.variance_scaling())
454+
pointwise_initializer='variance_scaling',
455+
depthwise_initializer='variance_scaling')
456456
else:
457457
conv2d_layer = functools.partial(
458458
tf.keras.layers.Conv2D,
@@ -537,8 +537,8 @@ def __init__(self,
537537
tf.keras.layers.SeparableConv2D(
538538
filters=self.num_filters,
539539
depth_multiplier=1,
540-
pointwise_initializer=tf.initializers.variance_scaling(),
541-
depthwise_initializer=tf.initializers.variance_scaling(),
540+
pointwise_initializer='variance_scaling',
541+
depthwise_initializer='variance_scaling',
542542
data_format=self.data_format,
543543
kernel_size=3,
544544
activation=None,
@@ -612,8 +612,8 @@ def boxes_layer(cls, separable_conv, num_anchors, data_format, name):
612612
return tf.keras.layers.SeparableConv2D(
613613
filters=4 * num_anchors,
614614
depth_multiplier=1,
615-
pointwise_initializer=tf.initializers.variance_scaling(),
616-
depthwise_initializer=tf.initializers.variance_scaling(),
615+
pointwise_initializer='variance_scaling',
616+
depthwise_initializer='variance_scaling',
617617
data_format=data_format,
618618
kernel_size=3,
619619
activation=None,

efficientdet/tf2/efficientdet_keras_test.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tempfile
1818
from absl import logging
1919
import tensorflow.compat.v1 as tf
20+
import tensorflow.compat.v2 as tf2
2021
import efficientdet_arch as legacy_arch
2122
import hparams_config
2223
from tf2 import efficientdet_keras
@@ -38,26 +39,37 @@ def test_model_output(self):
3839
inputs_shape = [1, 512, 512, 3]
3940
config = hparams_config.get_efficientdet_config('efficientdet-d0')
4041
config.heads = ['object_detection', 'segmentation']
42+
tf2.keras.utils.set_random_seed(SEED)
4143
with tf.Session(graph=tf.Graph()) as sess:
4244
feats = tf.ones(inputs_shape)
43-
tf.random.set_random_seed(SEED)
4445
model = efficientdet_keras.EfficientDetNet(config=config)
4546
outputs = model(feats, True)
4647
sess.run(tf.global_variables_initializer())
4748
keras_class_out, keras_box_out, _ = sess.run(outputs)
4849
grads = tf.nest.map_structure(lambda output: tf.gradients(output, feats),
4950
outputs)
5051
keras_class_grads, keras_box_grads, _ = sess.run(grads)
52+
vars = list(filter(
53+
lambda var: not var.name.startswith('segmentation'),
54+
tf.global_variables()))
55+
vars.sort(key=lambda var: var.name)
56+
keras_vars_names = [var.name for var in vars]
57+
keras_vars_values = sess.run(vars)
58+
5159
with tf.Session(graph=tf.Graph()) as sess:
5260
feats = tf.ones(inputs_shape)
53-
tf.random.set_random_seed(SEED)
5461
outputs = legacy_arch.efficientdet(feats, config=config)
55-
sess.run(tf.global_variables_initializer())
62+
vars = tf.global_variables()
63+
vars.sort(key=lambda var: var.name)
64+
legacy_vars_names = [var.name for var in vars]
65+
sess.run([var.assign(val) for val, var in zip(keras_vars_values, vars)])
5666
legacy_class_out, legacy_box_out = sess.run(outputs)
5767
grads = tf.nest.map_structure(lambda output: tf.gradients(output, feats),
5868
outputs)
5969
legacy_class_grads, legacy_box_grads = sess.run(grads)
6070

71+
self.assertAllEqual(keras_vars_names, legacy_vars_names)
72+
6173
for i in range(3, 8):
6274
self.assertAllEqual(
6375
keras_class_out[i - 3], legacy_class_out[i])
@@ -76,7 +88,7 @@ def test_eager_output(self):
7688

7789
with tf.Session(graph=tf.Graph()) as sess:
7890
feats = tf.ones(inputs_shape)
79-
tf.random.set_random_seed(SEED)
91+
tf2.keras.utils.set_random_seed(SEED)
8092
model = efficientdet_keras.EfficientDetNet(config=config)
8193
outputs = model(feats, True)
8294
grads = tf.nest.map_structure(lambda output: tf.gradients(output, feats),
@@ -120,7 +132,7 @@ def test_build_feature_network(self):
120132
tf.ones([1, 32, 32, 112]), # level 4
121133
tf.ones([1, 16, 16, 320]), # level 5
122134
]
123-
tf.random.set_random_seed(SEED)
135+
tf2.keras.utils.set_random_seed(SEED)
124136
fpn_cell = efficientdet_keras.FPNCells(config)
125137
new_feats1 = fpn_cell(inputs, True)
126138
sess.run(tf.global_variables_initializer())
@@ -137,7 +149,7 @@ def test_build_feature_network(self):
137149
4: tf.ones([1, 32, 32, 112]),
138150
5: tf.ones([1, 16, 16, 320])
139151
}
140-
tf.random.set_random_seed(SEED)
152+
tf2.keras.utils.set_random_seed(SEED)
141153
new_feats2 = legacy_arch.build_feature_network(inputs, config)
142154
sess.run(tf.global_variables_initializer())
143155
legacy_feats = sess.run(new_feats2)
@@ -185,7 +197,7 @@ def test_resample_feature_map(self):
185197
for strategy in ['tpu', '']:
186198
with self.subTest(
187199
apply_bn=apply_bn, training=training, strategy=strategy):
188-
tf.random.set_random_seed(SEED)
200+
tf2.keras.utils.set_random_seed(SEED)
189201
expect_result = legacy_arch.resample_feature_map(
190202
feat,
191203
name='resample_p0',
@@ -195,7 +207,7 @@ def test_resample_feature_map(self):
195207
apply_bn=apply_bn,
196208
is_training=training,
197209
strategy=strategy)
198-
tf.random.set_random_seed(SEED)
210+
tf2.keras.utils.set_random_seed(SEED)
199211
resample_layer = efficientdet_keras.ResampleFeatureMap(
200212
name='resample_p0',
201213
feat_level=0,

efficientdet/tf2/infer_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def predict(self, image_arrays):
214214
Returns:
215215
Model outputs.
216216
"""
217-
raise NotImplemented
217+
raise NotImplementedError
218218

219219
def _preprocess(self, image_arrays):
220220

@@ -249,7 +249,7 @@ def serve(self, image_arrays):
249249
Returns:
250250
A list of detections.
251251
"""
252-
raise NotImplemented
252+
raise NotImplementedError
253253

254254
def _get_model_and_spec(self, tflite=None):
255255
"""Get model instance and export spec."""

efficientdet/tf2/infer_lib_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_infer_lib(self):
9090
images = tf.ones((1, 512, 512, 3))
9191
boxes, scores, classes, valid_lens = driver.serve(images)
9292
self.assertEqual(tf.reduce_mean(boxes), 163.09)
93-
self.assertEqual(tf.reduce_mean(scores), 0.01000005)
93+
self.assertEqual(tf.reduce_mean(scores), 0.01)
9494
self.assertEqual(tf.reduce_mean(classes), 1)
9595
self.assertEqual(tf.reduce_mean(valid_lens), 100)
9696
self.assertEqual(boxes.shape, (1, 100, 4))
@@ -107,7 +107,7 @@ def test_infer_lib_without_ema(self):
107107
images = tf.ones((1, 512, 512, 3))
108108
boxes, scores, classes, valid_lens = driver.serve(images)
109109
self.assertEqual(tf.reduce_mean(boxes), 163.09)
110-
self.assertEqual(tf.reduce_mean(scores), 0.01000005)
110+
self.assertEqual(tf.reduce_mean(scores), 0.01)
111111
self.assertEqual(tf.reduce_mean(classes), 1)
112112
self.assertEqual(tf.reduce_mean(valid_lens), 100)
113113
self.assertEqual(boxes.shape, (1, 100, 4))
@@ -134,7 +134,7 @@ def test_infer_lib_mixed_precision(self):
134134
policy = tf.keras.mixed_precision.global_policy()
135135
if policy.name == 'float32':
136136
self.assertEqual(tf.reduce_mean(boxes), 163.09)
137-
self.assertEqual(tf.reduce_mean(scores), 0.01000005)
137+
self.assertEqual(tf.reduce_mean(scores), 0.01)
138138
self.assertEqual(tf.reduce_mean(classes), 1)
139139
self.assertEqual(tf.reduce_mean(valid_lens), 100)
140140
elif policy.name == 'float16':

efficientdet/tf2/train_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,10 @@ def train_step(self, data):
661661
scaled_loss = total_loss
662662
optimizer = self.optimizer
663663
loss_vals['loss'] = total_loss
664-
loss_vals['learning_rate'] = optimizer.learning_rate(optimizer.iterations)
664+
if callable(optimizer.learning_rate):
665+
loss_vals['learning_rate'] = optimizer.learning_rate(optimizer.iterations)
666+
else:
667+
loss_vals['learning_rate'] = optimizer.learning_rate
665668
trainable_vars = self._freeze_vars()
666669
scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
667670
if isinstance(self.optimizer,

efficientdet/tf2/train_lib_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ def test_fit(self):
196196
# skip gnorm test because it is flaky.
197197

198198
def test_recompute_grad(self):
199-
tf.config.run_functions_eagerly(True)
200199
_, x, labels, model = self._build_model(False)
200+
weights = model.get_weights()
201201
with tf.GradientTape() as tape:
202202
loss_vals = {}
203203
cls_outputs, box_outputs, _ = model(x, training=True)
@@ -206,6 +206,7 @@ def test_recompute_grad(self):
206206
grads1 = tape.gradient(det_loss, model.trainable_variables)
207207

208208
_, x, labels, model = self._build_model(True)
209+
model.set_weights(weights)
209210
with tf.GradientTape() as tape:
210211
loss_vals = {}
211212
cls_outputs2, box_outputs2, _ = model(x, training=True)

0 commit comments

Comments
 (0)