Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit ee3369b

Browse files
committed
Added TPU Estimator Prediction
1 parent 8370b80 commit ee3369b

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

E1_TPU_Sample/image_retraining_tpu.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,18 @@
2323
"horses_or_humans",
2424
"TFDS Dataset Name. IMAGE Dimension should be >= 224, channel=3")
2525
flags.DEFINE_string("data_dir", None, "Directory to Save Data to")
26+
flags.DEFINE_string("infer", None, "Dummy image file to infer")
2627

2728
FLAGS = flags.FLAGS
2829
NUM_CLASSES = None
2930

3031

32+
def resize_and_scale(image, label):
33+
image = tf.image.resize(image, size=[224, 224])
34+
image = tf.cast(image, tf.float32)
35+
image = image / tf.reduce_max(tf.gather(image, 0))
36+
return image, label
37+
3138
def input_(mode, batch_size, iterations, **kwargs):
3239
global NUM_CLASSES
3340
dataset, info = tfds.load(
@@ -38,13 +45,6 @@ def input_(mode, batch_size, iterations, **kwargs):
3845
data_dir=kwargs['data_dir']
3946
)
4047
NUM_CLASSES = info.features['label'].num_classes
41-
42-
def resize_and_scale(image, label):
43-
image = tf.image.resize(image, size=[224, 224])
44-
image = tf.cast(image, tf.float32)
45-
image = image / tf.reduce_max(tf.gather(image, 0))
46-
return image, label
47-
4848
dataset = dataset.map(resize_and_scale).shuffle(
4949
1000).repeat(iterations).batch(batch_size, drop_remainder=True)
5050
return dataset
@@ -135,9 +135,16 @@ def main(_):
135135
input_fn=lambda params: input_fn(
136136
mode=tf.estimator.ModeKeys.TRAIN,
137137
**params),
138-
max_steps=1000)
138+
max_steps=None, steps=None)
139139
# TODO(@captain-pool): Implement Evaluation
140-
140+
if FLAGS.infer:
141+
def prepare_input_fn(path):
142+
img = tf.image.decode_image(tf.io.read_file(path))
143+
return resize_and_scale(img, None)
144+
145+
predictions = classifer.predict(
146+
input_fn=lambda params: prepare_input_fn(FLAGS.infer))
147+
print(predictions)
141148

142149
if __name__ == "__main__":
143150
app.run(main)

E1_TPU_Sample/testdata/0.jpg

7.38 KB
Loading

0 commit comments

Comments
 (0)