
Commit 513a0ec (1 parent: ee3369b)

Added max_steps parameter


E1_TPU_Sample/image_retraining_tpu.py (1 file changed: 11 additions, 8 deletions)
@@ -14,6 +14,7 @@
 flags.DEFINE_float("learning_rate", 1e-3, "Learning Rate")
 flags.DEFINE_boolean("use_tpu", True, " Use TPU")
 flags.DEFINE_boolean("use_compat", True, "Use OptimizerV1 from compat module")
+flags.DEFINE_integer("max_steps", 1000, "Maximum Number of Steps for TPU Estimator")
 flags.DEFINE_string(
     "model_dir",
     "model_dir/",
@@ -67,7 +68,7 @@ def model_fn(features, labels, mode, params):
   else:
     optimizer = tf.compat.v1.train.AdamOptimizer(
         params["learning_rate"])
-  if params.get["use_tpu"]:
+  if params["use_tpu"]:
     optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

   with tf.GradientTape() as tape:
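The one-line change here fixes a genuine bug, not style: dict.get is a bound method, so subscripting it as params.get["use_tpu"] raises a TypeError as soon as the branch runs. A standalone illustration (not from the script):

# dict.get is a method; the original code subscripted it instead of
# calling it, which fails at runtime.
params = {"use_tpu": True}

try:
    params.get["use_tpu"]  # the buggy original form
except TypeError as err:
    print(err)  # 'builtin_function_or_method' object is not subscriptable

print(params.get("use_tpu"))  # correct call form: True (None if key absent)
print(params["use_tpu"])      # form the commit chose: True (KeyError if absent)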
@@ -95,7 +96,7 @@ def train_fn(use_compat):
         zip(gradient, model.trainable_variables))
   else:
     apply_grads = optimizer.apply_gradients(
-        zip(gradient, model_trainable_variables),
+        zip(gradient, model.trainable_variables),
         global_step=global_step)
   return apply_grads
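This hunk repairs a typo: model_trainable_variables was never defined anywhere, so the else branch raised NameError on first use, while model.trainable_variables is the standard Keras attribute. A self-contained sketch of the corrected apply_gradients pattern (the model, loss, and shapes below are illustrative assumptions, not the script's):

import tensorflow as tf

# Illustrative stand-ins for the script's model and data.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)
global_step = tf.compat.v1.train.get_or_create_global_step()

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(model(tf.ones((2, 4))) ** 2)
gradient = tape.gradient(loss, model.trainable_variables)

# Attribute access, as in the fixed line; model_trainable_variables
# (with an underscore) would be an undefined name here too.
apply_grads = optimizer.apply_gradients(
    zip(gradient, model.trainable_variables),
    global_step=global_step)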

@@ -130,12 +131,14 @@ def main(_):
           "learning_rate": FLAGS.learning_rate
       }
   )
-
-  classifier.train(
-      input_fn=lambda params: input_fn(
-          mode=tf.estimator.ModeKeys.TRAIN,
-          **params),
-      max_steps=None, steps=None)
+  try:
+    classifier.train(
+        input_fn=lambda params: input_fn(
+            mode=tf.estimator.ModeKeys.TRAIN,
+            **params),
+        max_steps=FLAGS.max_steps)
+  except Exception:
+    pass
   # TODO(@captain-pool): Implement Evaluation
   if FLAGS.infer:
     def prepare_input_fn(path):
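Two things change in this last hunk: training is now bounded by FLAGS.max_steps instead of the old max_steps=None, steps=None (which set no limit), and the call is wrapped in except Exception: pass, which silently swallows every failure, including genuine bugs. A self-contained sketch of a gentler alternative that logs the traceback instead (my suggestion, not what the commit does; train stands in for classifier.train):

import logging

def train(max_steps):
    # Stand-in for classifier.train; pretend the TPU job fails.
    raise RuntimeError("simulated TPU failure")

try:
    train(max_steps=1000)
except Exception:
    # Unlike a bare `pass`, this keeps the traceback for debugging.
    logging.exception("TPUEstimator training failed")

At launch, the default of 1000 can be overridden on the command line, e.g. python E1_TPU_Sample/image_retraining_tpu.py --max_steps=500.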
