2323 "horses_or_humans" ,
2424 "TFDS Dataset Name. IMAGE Dimension should be >= 224, channel=3" )
2525flags .DEFINE_string ("data_dir" , None , "Directory to Save Data to" )
26+ flags .DEFINE_string ("infer" , None , "Dummy image file to infer" )
2627
2728FLAGS = flags .FLAGS
2829NUM_CLASSES = None
2930
3031
32+ def resize_and_scale (image , label ):
33+ image = tf .image .resize (image , size = [224 , 224 ])
34+ image = tf .cast (image , tf .float32 )
35+ image = image / tf .reduce_max (tf .gather (image , 0 ))
36+ return image , label
37+
3138def input_ (mode , batch_size , iterations , ** kwargs ):
3239 global NUM_CLASSES
3340 dataset , info = tfds .load (
@@ -38,13 +45,6 @@ def input_(mode, batch_size, iterations, **kwargs):
3845 data_dir = kwargs ['data_dir' ]
3946 )
4047 NUM_CLASSES = info .features ['label' ].num_classes
41-
42- def resize_and_scale (image , label ):
43- image = tf .image .resize (image , size = [224 , 224 ])
44- image = tf .cast (image , tf .float32 )
45- image = image / tf .reduce_max (tf .gather (image , 0 ))
46- return image , label
47-
4848 dataset = dataset .map (resize_and_scale ).shuffle (
4949 1000 ).repeat (iterations ).batch (batch_size , drop_remainder = True )
5050 return dataset
@@ -135,9 +135,16 @@ def main(_):
135135 input_fn = lambda params : input_fn (
136136 mode = tf .estimator .ModeKeys .TRAIN ,
137137 ** params ),
138- max_steps = 1000 )
138+ max_steps = None , steps = None )
139139 # TODO(@captain-pool): Implement Evaluation
140-
140+ if FLAGS .infer :
141+ def prepare_input_fn (path ):
142+ img = tf .image .decode_image (tf .io .read_file (path ))
143+ return resize_and_scale (img , None )
144+
145+ predictions = classifer .predict (
146+ input_fn = lambda params : prepare_input_fn (FLAGS .infer ))
147+ print (predictions )
141148
142149if __name__ == "__main__" :
143150 app .run (main )
0 commit comments