Skip to content

Commit da4dc73

Browse files
committed
Add a simple get_model for finetuning and easy of use.
1 parent 49e2097 commit da4dc73

File tree

2 files changed

+112
-73
lines changed

2 files changed

+112
-73
lines changed

efficientnetv2/effnetv2_model.py

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -459,13 +459,6 @@ def __init__(self, mconfig, name=None):
459459

460460
self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
461461
data_format=mconfig.data_format)
462-
if mconfig.num_classes:
463-
self._fc = tf.keras.layers.Dense(
464-
mconfig.num_classes,
465-
kernel_initializer=dense_kernel_initializer,
466-
bias_initializer=tf.constant_initializer(mconfig.headbias or 0))
467-
else:
468-
self._fc = None
469462

470463
if mconfig.dropout_rate > 0:
471464
self._dropout = tf.keras.layers.Dropout(mconfig.dropout_rate)
@@ -498,9 +491,6 @@ def call(self, inputs, training):
498491
self.endpoints['pooled_features'] = outputs
499492
if self._dropout:
500493
outputs = self._dropout(outputs, training=training)
501-
self.endpoints['global_pool'] = outputs
502-
if self._fc:
503-
outputs = self._fc(outputs)
504494
self.endpoints['head'] = outputs
505495
return outputs
506496

@@ -514,12 +504,13 @@ class EffNetV2Model(tf.keras.Model):
514504
def __init__(self,
515505
model_name='efficientnetv2-s',
516506
model_config=None,
507+
include_top=True,
517508
name=None):
518509
"""Initializes an `Model` instance.
519510
520511
Args:
521512
model_name: A string of model name.
522-
model_config: A dict of model configureations or a string of hparams.
513+
model_config: A dict of model configurations or a string of hparams.
523514
name: A string of layer name.
524515
525516
Raises:
@@ -533,6 +524,7 @@ def __init__(self,
533524
self.cfg = cfg
534525
self._mconfig = cfg.model
535526
self.endpoints = None
527+
self.include_top = include_top
536528
self._build()
537529

538530
def _build(self):
@@ -574,12 +566,25 @@ def _build(self):
574566
# Head part.
575567
self._head = Head(self._mconfig)
576568

569+
# top part for classification
570+
if self.include_top and self._mconfig.num_classes:
571+
self._fc = tf.keras.layers.Dense(
572+
self._mconfig.num_classes,
573+
kernel_initializer=dense_kernel_initializer,
574+
bias_initializer=tf.constant_initializer(self._mconfig.headbias or 0))
575+
else:
576+
self._fc = None
577+
577578
def summary(self, input_shape=(224, 224, 3), **kargs):
578579
x = tf.keras.Input(shape=input_shape)
579580
model = tf.keras.Model(inputs=[x], outputs=self.call(x, training=True))
580581
return model.summary()
581582

582-
def call(self, inputs, training, features_only=None, single_out=None):
583+
def get_model_with_inputs(self, inputs, **kargs):
584+
model = tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training=True))
585+
return model
586+
587+
def call(self, inputs, training, with_endpoints=False):
583588
"""Implementation of call().
584589
585590
Args:
@@ -624,19 +629,70 @@ def call(self, inputs, training, features_only=None, single_out=None):
624629
self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
625630
self.endpoints['features'] = outputs
626631

627-
if not features_only:
628-
# Calls final layers and returns logits.
629-
outputs = self._head(outputs, training)
630-
self.endpoints.update(self._head.endpoints)
631-
632-
if single_out: # Use for building sequential models.
633-
return outputs
634-
635-
return [outputs] + list(
636-
filter(lambda endpoint: endpoint is not None, [
637-
self.endpoints.get('reduction_1'),
638-
self.endpoints.get('reduction_2'),
639-
self.endpoints.get('reduction_3'),
640-
self.endpoints.get('reduction_4'),
641-
self.endpoints.get('reduction_5'),
642-
]))
632+
# Head to obtain the final feature.
633+
outputs = self._head(outputs, training)
634+
self.endpoints.update(self._head.endpoints)
635+
636+
# Calls final dense layers and returns logits.
637+
if self._fc:
638+
with tf.name_scope('head'): # legacy
639+
outputs = self._fc(outputs)
640+
641+
if with_endpoints: # Use for building sequential models.
642+
return [outputs] + list(
643+
filter(lambda endpoint: endpoint is not None, [
644+
self.endpoints.get('reduction_1'),
645+
self.endpoints.get('reduction_2'),
646+
self.endpoints.get('reduction_3'),
647+
self.endpoints.get('reduction_4'),
648+
self.endpoints.get('reduction_5'),
649+
]))
650+
651+
return outputs
652+
653+
654+
def get_model(model_name,
655+
model_config=None,
656+
include_top=True,
657+
pretrained=True,
658+
training=True,
659+
with_endpoints=False,
660+
**kargs):
661+
"""Get a EfficientNet V1 or V2 model instance.
662+
663+
This is a simply utility for finetuning or inference.
664+
665+
Args:
666+
model_name: a string such as 'efficientnetv2-s' or 'efficientnet-b0'.
667+
model_config: A dict of model configurations or a string of hparams.
668+
include_top: whether to include the final dense layer for classification.
669+
pretrained: if true, download the checkpoint. If string, load the ckpt.
670+
training: If true, all model variables are trainable.
671+
with_endpoints: whether to return all intermedia endpoints.
672+
673+
Returns:
674+
A single tensor if with_endpoints if False; otherwise, a list of tensor.
675+
"""
676+
net = EffNetV2Model(model_name, model_config, include_top)
677+
net(tf.keras.Input(shape=(None, None, 3)),
678+
training=training,
679+
with_endpoints=with_endpoints)
680+
if pretrained is True:
681+
# download checkpoint and set pretrained path. Supported models include:
682+
# efficientnetv2-s, efficientnetv2-m, efficientnetv2-l,
683+
# efficientnetv2-b0, efficientnetv2-b1, efficientnetv2-b2, efficientnetv2-b3,
684+
# efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3,
685+
# efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7, efficientnet-l2
686+
# More V2 ckpts: https://github.com/google/automl/tree/master/efficientnetv2
687+
# More V1 ckpts: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
688+
url = f'https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/v2/{model_name}.tgz'
689+
pretrained_ckpt = tf.keras.utils.get_file(model_name, url, untar=True)
690+
else:
691+
pretrained_ckpt = pretrained
692+
693+
if pretrained_ckpt:
694+
if tf.io.gfile.isdir(pretrained_ckpt):
695+
pretrained_ckpt = tf.train.latest_checkpoint(pretrained_ckpt)
696+
net.load_weights(pretrained_ckpt)
697+
698+
return net

efficientnetv2/infer.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -44,94 +44,79 @@ def define_flags():
4444
flags.DEFINE_string('export_dir', None, 'Export or saved model directory')
4545
flags.DEFINE_string('trace_file', '/tmp/a.trace', 'If set, dump trace file.')
4646
flags.DEFINE_integer('batch_size', 16, 'Batch size.')
47+
flags.DEFINE_bool('mixed_precision', False, 'If True, use mixed precision.')
4748

4849

4950
def get_config(model_name, dataset_cfg, hparam_str=''):
5051
"""Create a keras model for EffNetV2."""
5152
config = copy.deepcopy(effnetv2_configs.get_model_config(model_name))
52-
config.override(datasets.get_dataset_config(dataset_cfg))
53-
config.override(hparam_str)
53+
config.update(datasets.get_dataset_config(dataset_cfg))
54+
config.override(hparam_str, allow_new_keys=True)
5455
config.model.num_classes = config.data.num_classes
5556
return config
5657

5758

5859
def build_tf2_model():
5960
"""Build the tf2 model."""
6061
tf.config.run_functions_eagerly(FLAGS.debug)
61-
config = get_config(FLAGS.model_name, FLAGS.dataset_cfg, FLAGS.hparam_str)
62-
if config.runtime.mixed_precision:
62+
if FLAGS.mixed_precision:
6363
# Use 'mixed_float16' if running on GPUs.
6464
policy = tf.keras.mixed_precision.Policy('mixed_float16')
6565
tf.keras.mixed_precision.set_global_policy(policy)
6666

67-
model = effnetv2_model.EffNetV2Model(FLAGS.model_name, config.model)
68-
# Use call (not build) to match the namescope: tensorflow issues/29576
69-
model(tf.ones([1, 224, 224, 3]), False)
70-
if FLAGS.model_dir:
71-
ckpt = FLAGS.model_dir
72-
if tf.io.gfile.isdir(ckpt):
73-
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
74-
model.load_weights(ckpt)
67+
model = effnetv2_model.get_model(
68+
FLAGS.model_name, FLAGS.hparam_str, include_top=True, pretrained=FLAGS.model_dir or True)
7569
model.summary()
76-
77-
class ExportModel(tf.Module):
78-
"""Export a saved model."""
79-
80-
def __init__(self, model):
81-
super().__init__()
82-
self.model = model
83-
84-
@tf.function
85-
def f(self, images):
86-
return self.model(images, training=False)[0]
87-
88-
return ExportModel(model)
89-
70+
return model
9071

9172
def tf2_eval_dataset():
9273
"""Run TF2 benchmark and inference."""
93-
export_model = build_tf2_model()
94-
isize = FLAGS.image_size or export_model.model.cfg.eval.isize
74+
model = build_tf2_model()
75+
isize = FLAGS.image_size or model.cfg.eval.isize
9576

9677
def preprocess_fn(features):
9778
features['image'] = preprocessing.preprocess_image(
9879
features['image'], isize, is_training=False)
9980
return features
10081

82+
@tf.function
83+
def f(x):
84+
return model(x)
85+
10186
top1_acc = tf.keras.metrics.Accuracy()
10287
pbar = tf.keras.utils.Progbar(None)
10388
data = tfds.load('imagenet2012', split='validation')
10489
ds = data.map(preprocess_fn).batch(FLAGS.batch_size)
10590
for i, batch in enumerate(ds.prefetch(tf.data.experimental.AUTOTUNE)):
106-
logits = export_model.f(batch['image'])
91+
logits = f(batch['image'])
10792
top1_acc.update_state(batch['label'], tf.argmax(logits, axis=-1))
10893
pbar.update(i, [('top1', top1_acc.result().numpy())])
10994
print('\n top1= {:.4f}'.format(top1_acc.result().numpy()))
11095

11196

11297
def tf2_benchmark():
11398
"""Run TF2 benchmark and inference."""
114-
export_model = build_tf2_model()
115-
isize = FLAGS.image_size or export_model.model.cfg.eval.isize
99+
model = build_tf2_model()
100+
isize = FLAGS.image_size or model.cfg.eval.isize
116101
if FLAGS.export_dir:
117-
tf.saved_model.save(
118-
export_model,
119-
FLAGS.export_dir,
120-
signatures=export_model.f.get_concrete_function(
121-
tf.TensorSpec(shape=(None, isize, isize, 3), dtype=tf.float16)))
122-
export_model = tf.saved_model.load(FLAGS.export_dir)
102+
tf.saved_model.save(model, FLAGS.export_dir)
103+
model = tf.saved_model.load(FLAGS.export_dir)
123104

124105
batch_size = FLAGS.batch_size
125106
imgs = tf.ones((batch_size, isize, isize, 3), dtype=tf.float16)
126107

108+
@tf.function
109+
def f(x):
110+
return model(x)
111+
127112
print('starting warmup.')
128113
for _ in range(10): # warmup runs.
129-
export_model.f(imgs)
114+
f(imgs)
130115

131116
print('start benchmark.')
132117
start = time.perf_counter()
133118
for _ in range(10):
134-
export_model.f(imgs)
119+
f(imgs)
135120
end = time.perf_counter()
136121
inference_time = (end - start) / 10
137122

@@ -143,14 +128,13 @@ def tf1_benchmark():
143128
"""Run TF1 inference and benchmark."""
144129
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
145130
from tensorflow.python.client import timeline
146-
config = get_config(FLAGS.model_name, FLAGS.dataset_cfg, FLAGS.hparam_str)
147131
with tf1.Session() as sess:
148-
model = effnetv2_model.EffNetV2Model(FLAGS.model_name, config.model)
132+
model = effnetv2_model.EffNetV2Model(FLAGS.model_name, FLAGS.hparam_str)
149133
batch_size = FLAGS.batch_size
150134
run_options = tf1.RunOptions(
151135
trace_level=tf1.RunOptions.FULL_TRACE)
152136
run_metadata = tf1.RunMetadata()
153-
isize = FLAGS.image_size or config.eval.isize
137+
isize = FLAGS.image_size or model.cfg.eval.isize
154138
inputs = tf.ones((batch_size, isize, isize, 3), tf.float16)
155139
output = model(inputs, training=False)
156140
sess.run(tf1.global_variables_initializer())
@@ -179,10 +163,9 @@ def tf1_benchmark():
179163
def tf1_export_ema_ckpt():
180164
"""Restore variables from a given checkpoint."""
181165
with tf1.Session() as sess:
182-
config = get_config(FLAGS.model_name, FLAGS.dataset_cfg, FLAGS.hparam_str)
183-
model = effnetv2_model.EffNetV2Model(FLAGS.model_name, config.model)
166+
model = effnetv2_model.EffNetV2Model(FLAGS.model_name, FLAGS.hparam_str)
184167
batch_size = FLAGS.batch_size
185-
isize = FLAGS.image_size or config.eval.isize
168+
isize = FLAGS.image_size or model.cfg.eval.isize
186169
inputs = tf.ones((batch_size, isize, isize, 3), tf.float32)
187170
_ = model(inputs, training=False)
188171
sess.run(tf1.global_variables_initializer())

0 commit comments

Comments
 (0)