Skip to content

Commit 6bddef7

Browse files
committed
Fix the argument name and savedmodel precision issue.
Fix #1077 and #1079
1 parent c2ce63a commit 6bddef7

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

efficientnetv2/effnetv2_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def get_model_with_inputs(self, inputs, **kargs):
585585
inputs=[inputs], outputs=self.call(inputs, training=True))
586586
return model
587587

588-
def call(self, inputs, training, with_endpoints=False):
588+
def call(self, inputs, training=False, with_endpoints=False):
589589
"""Implementation of call().
590590
591591
Args:

efficientnetv2/infer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def build_tf2_model():
5656
FLAGS.model_name,
5757
FLAGS.hparam_str,
5858
include_top=True,
59-
pretrained=FLAGS.model_dir or True)
59+
weights=FLAGS.model_dir or 'imagenet')
6060
model.summary()
6161
return model
6262

@@ -95,11 +95,12 @@ def tf2_benchmark():
9595
model = tf.saved_model.load(FLAGS.export_dir)
9696

9797
batch_size = FLAGS.batch_size
98-
imgs = tf.ones((batch_size, isize, isize, 3), dtype=tf.float16)
98+
data_dtype = tf.float16 if FLAGS.mixed_precision else tf.float32
99+
imgs = tf.ones((batch_size, isize, isize, 3), dtype=data_dtype)
99100

100101
@tf.function
101102
def f(x):
102-
return model(x)
103+
return model(x, training=False)
103104

104105
print('starting warmup.')
105106
for _ in range(10): # warmup runs.
@@ -126,7 +127,8 @@ def tf1_benchmark():
126127
run_options = tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE)
127128
run_metadata = tf1.RunMetadata()
128129
isize = FLAGS.image_size or model.cfg.eval.isize
129-
inputs = tf.ones((batch_size, isize, isize, 3), tf.float16)
130+
data_dtype = tf.float16 if FLAGS.mixed_precision else tf.float32
131+
inputs = tf.ones((batch_size, isize, isize, 3), data_dtype)
130132
output = model(inputs, training=False)
131133
sess.run(tf1.global_variables_initializer())
132134

0 commit comments

Comments
 (0)