Skip to content

Commit e9ec14e

Browse files
author
Peter Steinbach
authored
Merge pull request #12 from psteinb/try-tf-1.6
Try tf 1.6
2 parents 98fa9c6 + 227a4d2 commit e9ec14e

File tree

5 files changed

+14
-12
lines changed

5 files changed

+14
-12
lines changed

models/__init__.py

Whitespace-only changes.

models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ def versions(self):
121121

122122
value = ""
123123

124-
if "keras" in self.backend.lower():
124+
if self.backend.lower().startswith("keras"):
125125

126126
import keras
127127
from keras import backend as K
@@ -143,7 +143,7 @@ def versions(self):
143143

144144
else:
145145

146-
if "tensorflow" in self.backend.lower():
146+
if self.backend.lower() == "tensorflow" or self.backend.lower() == "tf":
147147
import tensorflow as tf
148148
value = "tensorflow:{ver}".format(ver=tf.__version__)
149149
else:

models/tf_details/resnet_details.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def can_train():
3333
warnings.simplefilter(action='ignore', category=FutureWarning)
3434

3535
from tensorflow import __version__ as tfv
36-
required = "1.7.0"
36+
required = "1.6.0"
3737

3838
#only require major and minor release number as the patch number may contain 'rc' etc
3939
if versiontuple(tfv,2) >= versiontuple(required,2):
@@ -111,4 +111,9 @@ def train(train, test, datafraction, opts):
111111
logging.info('handing over \n >> %s \n >> %s',flags,opts)
112112
history, timings = run_loop.resnet_main(flags, cfmain.cifar10_model_fn, cfmain.input_fn, opts)
113113

114+
if not opts['checkpoint_epochs']:
115+
logging.info("unable to ensure pure no-checkpoint behavior with resnet in pure tensorflow, removing result directory")
116+
import shutil
117+
shutil.rmtree(model_dir)
118+
114119
return history, timings, { 'num_weights' : None }

models/tf_details/resnet_run_loop.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -327,6 +327,7 @@ def resnet_main(flags, model_function, input_function, opts = None):
327327
logging.warning("batch sizes differ in model %i %s", flags.batch_size, opts["batch_size"])
328328

329329
if ngpus > 1:
330+
steps_per_epoch -= 1
330331
validate_batch_size_for_multi_gpu(flags.batch_size)
331332
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
332333
# and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -407,12 +408,6 @@ def input_fn_eval():
407408
validation_results = classifier.evaluate(input_fn=input_fn_eval,
408409
steps=flags.max_train_steps)
409410

410-
411-
# for (k,v) in train_hooks["CaptureTensorsHook"].captured.items():
412-
# print(">> ",k,v[:5],v[-2:])
413-
414-
#epoch_times.extend(train_hooks["TimePerEpochHook"].epoch_durations)
415-
416411
for k in validation_results.keys():
417412
if "global_step" in k:
418413
continue

models/tf_details/utils/logging/hooks.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def add(self,other):
5050
class TimePerEpochHook(tf.train.SessionRunHook):
5151
def __init__(self,
5252
every_n_steps,
53-
warm_steps=0):
53+
warm_steps=-1):
5454

5555
self.every_n_steps = every_n_steps
5656
logging.info("TimePerEpochHook triggering every %i steps",every_n_steps)
@@ -112,8 +112,8 @@ def after_run(self, run_context, run_values): # pylint: disable=unused-argument
112112
global_step = run_values.results
113113
sess = run_context.session
114114

115-
116-
if self._timer.should_trigger_for_step(global_step) and global_step > self._warm_steps:
115+
#if self._timer.should_trigger_for_step(global_step) and global_step > self._warm_steps:
116+
if self._step % self.every_n_steps == 0:
117117
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
118118
global_step)
119119
if elapsed_time is not None:
@@ -124,6 +124,8 @@ def after_run(self, run_context, run_values): # pylint: disable=unused-argument
124124
tf.logging.info('Epoch [%g steps]: %g (%s)', self._total_steps,self._epoch_train_time,str(self.epoch_durations))
125125

126126
self._epoch_train_time = 0
127+
else:
128+
logging.warning("step %i, elapsed_time is None!", global_step)
127129

128130

129131

0 commit comments

Comments (0)