@@ -56,7 +56,7 @@ def build_tf2_model():
5656 FLAGS .model_name ,
5757 FLAGS .hparam_str ,
5858 include_top = True ,
59- pretrained = FLAGS .model_dir or True )
59+ weights = FLAGS .model_dir or 'imagenet' )
6060 model .summary ()
6161 return model
6262
@@ -95,11 +95,12 @@ def tf2_benchmark():
9595 model = tf .saved_model .load (FLAGS .export_dir )
9696
9797 batch_size = FLAGS .batch_size
98- imgs = tf .ones ((batch_size , isize , isize , 3 ), dtype = tf .float16 )
98+ data_dtype = tf .float16 if FLAGS .mixed_precision else tf .float32
99+ imgs = tf .ones ((batch_size , isize , isize , 3 ), dtype = data_dtype )
99100
100101 @tf .function
101102 def f (x ):
102- return model (x )
103+ return model (x , training = False )
103104
104105 print ('starting warmup.' )
105106 for _ in range (10 ): # warmup runs.
@@ -126,7 +127,8 @@ def tf1_benchmark():
126127 run_options = tf1 .RunOptions (trace_level = tf1 .RunOptions .FULL_TRACE )
127128 run_metadata = tf1 .RunMetadata ()
128129 isize = FLAGS .image_size or model .cfg .eval .isize
129- inputs = tf .ones ((batch_size , isize , isize , 3 ), tf .float16 )
130+ data_dtype = tf .float16 if FLAGS .mixed_precision else tf .float32
131+ inputs = tf .ones ((batch_size , isize , isize , 3 ), data_dtype )
130132 output = model (inputs , training = False )
131133 sess .run (tf1 .global_variables_initializer ())
132134
0 commit comments