Skip to content

Commit f2b4480

Browse files
committed
A few minor fixes for model call.
1 parent 94f47cd commit f2b4480

File tree

3 files changed

+5
-16
lines changed

3 files changed

+5
-16
lines changed

efficientnetv2/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def build_model(in_images):
8484
"""Build model using the model_name given through the command line."""
8585
config.model.num_classes = config.data.num_classes
8686
model = effnetv2_model.EffNetV2Model(config.model.model_name, config.model)
87-
logits = model(in_images, training=is_training)[0]
87+
logits = model(in_images, training=is_training)
8888
return logits
8989

9090
pre_num_params, pre_num_flops = utils.num_params_flops(readable_format=True)

efficientnetv2/main_tf2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def train_step(self, data):
9191
images, labels = features['image'], labels['label']
9292

9393
with tf.GradientTape() as tape:
94-
pred = self(images, training=True)[0]
94+
pred = self(images, training=True)
9595
pred = tf.cast(pred, tf.float32)
9696
loss = self.compiled_loss(
9797
labels,
@@ -105,7 +105,7 @@ def train_step(self, data):
105105
def test_step(self, data):
106106
features, labels = data
107107
images, labels = features['image'], labels['label']
108-
pred = self(images, training=False)[0]
108+
pred = self(images, training=False)
109109
pred = tf.cast(pred, tf.float32)
110110

111111
self.compiled_loss(
@@ -174,9 +174,9 @@ def main(_) -> None:
174174
weight_decay=config.train.weight_decay)
175175

176176
if config.train.ft_init_ckpt: # load pretrained ckpt for finetuning.
177-
model(tf.ones([1, 224, 224, 3]))
177+
model(tf.keras.Input([None, None, 3]))
178178
ckpt = config.train.ft_init_ckpt
179-
utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_head', 'optimizer'))
179+
utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_fc', 'optimizer'))
180180

181181
steps_per_epoch = num_train_images // config.train.batch_size
182182
total_steps = steps_per_epoch * config.train.epochs

efficientnetv2/utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,6 @@ def _moments(self, inputs, reduction_axes, keep_dims):
203203

204204
def call(self, inputs, training=None):
205205
outputs = super().call(inputs, training)
206-
# A temporary hack for tf1 compatibility with keras batch norm.
207-
for u in self.updates:
208-
tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u)
209206
return outputs
210207

211208

@@ -217,14 +214,6 @@ def __init__(self, **kwargs):
217214
kwargs['name'] = 'tpu_batch_normalization'
218215
super().__init__(**kwargs)
219216

220-
def call(self, inputs, training=None):
221-
outputs = super().call(inputs, training)
222-
if training and not tf.executing_eagerly():
223-
# A temporary hack for tf1 compatibility with keras batch norm.
224-
for u in self.updates:
225-
tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u)
226-
return outputs
227-
228217

229218
def normalization(norm_type: str,
230219
axis=-1,

0 commit comments

Comments
 (0)