Skip to content

Commit c84956a

Browse files
committed
Fix lite backbone builder.
Fix #357 Also add a few simple code cleanup.
1 parent cf91771 commit c84956a

File tree

5 files changed

+14
-16
lines changed

5 files changed

+14
-16
lines changed

efficientdet/backbone/efficientnet_builder_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _test_model_params(self,
3838
images,
3939
model_name=model_name,
4040
override_params=override_params,
41-
training=True,
41+
training=False,
4242
features_only=features_only,
4343
pooled_features_only=pooled_features_only)
4444
num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
@@ -91,15 +91,15 @@ def test_efficientnet_b0_fails_if_both_features_requested(self):
9191
efficientnet_builder.build_model(
9292
None,
9393
model_name='efficientnet-b0',
94-
training=True,
94+
training=False,
9595
features_only=True,
9696
pooled_features_only=True)
9797

9898
def test_efficientnet_b0_base(self):
9999
# Creates a base model using the model configuration.
100100
images = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
101101
_, endpoints = efficientnet_builder.build_model_base(
102-
images, model_name='efficientnet-b0', training=True)
102+
images, model_name='efficientnet-b0', training=False)
103103

104104
# reduction_1 to reduction_5 should be in endpoints
105105
self.assertIn('reduction_1', endpoints)

efficientdet/backbone/efficientnet_lite_builder.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,12 @@ def build_model(images,
161161
f.write('global_params= %s\n\n' % str(global_params))
162162
f.write('blocks_args= %s\n\n' % str(blocks_args))
163163

164-
with tf.variable_scope(model_name):
165-
model = efficientnet_model.Model(blocks_args, global_params)
166-
outputs = model(
167-
images,
168-
training=training,
169-
features_only=features_only,
170-
pooled_features_only=pooled_features_only)
164+
model = efficientnet_model.Model(blocks_args, global_params, model_name)
165+
outputs = model(
166+
images,
167+
training=training,
168+
features_only=features_only,
169+
pooled_features_only=pooled_features_only)
171170
if features_only:
172171
outputs = tf.identity(outputs, 'features')
173172
elif pooled_features_only:
@@ -202,9 +201,8 @@ def build_model_base(images, model_name, training, override_params=None):
202201

203202
blocks_args, global_params = get_model_params(model_name, override_params)
204203

205-
with tf.variable_scope(model_name):
206-
model = efficientnet_model.Model(blocks_args, global_params)
207-
features = model(images, training=training, features_only=True)
204+
model = efficientnet_model.Model(blocks_args, global_params, model_name)
205+
features = model(images, training=training, features_only=True)
208206

209207
features = tf.identity(features, 'features')
210208
return features, model.endpoints

efficientdet/backbone/efficientnet_lite_builder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _test_model_params(self,
3838
images,
3939
model_name=model_name,
4040
override_params=override_params,
41-
training=True,
41+
training=False,
4242
features_only=features_only,
4343
pooled_features_only=pooled_features_only)
4444
num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])

efficientdet/backbone/efficientnet_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_bottleneck_block(self):
3535
0,
3636
'channels_last',
3737
num_classes=10,
38-
batch_norm=utils.TpuBatchNormalization)
38+
batch_norm=utils.batch_norm_class(False))
3939
blocks_args = [
4040
efficientnet_model.BlockArgs(
4141
kernel_size=3,

efficientdet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def call(self, *args, **kwargs):
236236
return outputs
237237

238238

239-
def batch_norm_class(is_training, use_tpu=False,):
239+
def batch_norm_class(is_training, use_tpu=False):
240240
if is_training and use_tpu:
241241
return TpuBatchNormalization
242242
else:

0 commit comments

Comments
 (0)