3333
3434from dataset import tfrecord_util
3535
36- flags .DEFINE_string ('data_dir' , '' , 'Root directory to raw PASCAL VOC dataset.' )
37- flags .DEFINE_string ('set' , 'train' , 'Convert training set, validation set or '
38- 'merged set.' )
39- flags .DEFINE_string ('annotations_dir' , 'Annotations' ,
40- '(Relative) path to annotations directory.' )
41- flags .DEFINE_string ('year' , 'VOC2007' , 'Desired challenge year.' )
42- flags .DEFINE_string ('output_path' , '' , 'Path to output TFRecord and json.' )
43- flags .DEFINE_string ('label_map_json_path' , None ,
44- 'Path to label map json file with a dictionary.' )
45- flags .DEFINE_boolean ('ignore_difficult_instances' , False , 'Whether to ignore '
46- 'difficult instances' )
47- flags .DEFINE_integer ('num_shards' , 100 , 'Number of shards for output file.' )
48- flags .DEFINE_integer ('num_images' , None , 'Max number of imags to process.' )
4936FLAGS = flags .FLAGS
5037
5138SETS = ['train' , 'val' , 'trainval' , 'test' ]
7966GLOBAL_ANN_ID = 0 # global annotation id.
8067
8168
69+ def define_flags ():
70+ """Define the flags."""
71+ flags .DEFINE_string ('data_dir' , '' ,
72+ 'Root directory to raw PASCAL VOC dataset.' )
73+ flags .DEFINE_string ('set' , 'train' , 'Convert training set, validation set or '
74+ 'merged set.' )
75+ flags .DEFINE_string ('annotations_dir' , 'Annotations' ,
76+ '(Relative) path to annotations directory.' )
77+ flags .DEFINE_string ('year' , 'VOC2007' , 'Desired challenge year.' )
78+ flags .DEFINE_string ('output_path' , '' , 'Path to output TFRecord and json.' )
79+ flags .DEFINE_string ('label_map_json_path' , None ,
80+ 'Path to label map json file with a dictionary.' )
81+ flags .DEFINE_boolean ('ignore_difficult_instances' , False , 'Whether to ignore '
82+ 'difficult instances' )
83+ flags .DEFINE_integer ('num_shards' , 100 , 'Number of shards for output file.' )
84+ flags .DEFINE_integer ('num_images' , None , 'Max number of imags to process.' )
85+
86+
8287def get_image_id (filename ):
8388 """Convert a string to a integer."""
8489 # Warning: this function is highly specific to pascal filename!!
@@ -101,10 +106,9 @@ def get_ann_id():
101106
102107
103108def dict_to_tf_example (data ,
104- dataset_directory ,
109+ images_dir ,
105110 label_map_dict ,
106111 ignore_difficult_instances = False ,
107- image_subdirectory = 'JPEGImages' ,
108112 ann_json_dict = None ):
109113 """Convert XML derived dict to tf.Example proto.
110114
@@ -114,12 +118,10 @@ def dict_to_tf_example(data,
114118 Args:
115119 data: dict holding PASCAL XML fields for a single image (obtained by running
116120 tfrecord_util.recursive_parse_xml_to_dict)
117- dataset_directory : Path to root directory holding PASCAL dataset
121+ images_dir : Path to the directory holding raw images.
118122 label_map_dict: A map from string label names to integers ids.
119123 ignore_difficult_instances: Whether to skip difficult instances in the
120124 dataset (default: False).
121- image_subdirectory: String specifying subdirectory within the PASCAL dataset
122- directory holding the actual image data.
123125 ann_json_dict: annotation json dictionary.
124126
125127 Returns:
@@ -128,8 +130,7 @@ def dict_to_tf_example(data,
128130 Raises:
129131 ValueError: if the image pointed to by data['filename'] is not a valid JPEG
130132 """
131- img_path = os .path .join (data ['folder' ], image_subdirectory , data ['filename' ])
132- full_path = os .path .join (dataset_directory , img_path )
133+ full_path = os .path .join (images_dir , data ['filename' ])
133134 with tf .io .gfile .GFile (full_path , 'rb' ) as fid :
134135 encoded_jpg = fid .read ()
135136 encoded_jpg_io = io .BytesIO (encoded_jpg )
@@ -297,9 +298,10 @@ def main(_):
297298 xml = etree .fromstring (xml_str )
298299 data = tfrecord_util .recursive_parse_xml_to_dict (xml )['annotation' ]
299300
301+ img_dir = os .path .join (FLAGS .data_dir , data ['folder' ], 'JPEGImages' )
300302 tf_example = dict_to_tf_example (
301303 data ,
302- FLAGS . data_dir ,
304+ img_dir ,
303305 label_map_dict ,
304306 FLAGS .ignore_difficult_instances ,
305307 ann_json_dict = ann_json_dict )
@@ -316,4 +318,5 @@ def main(_):
316318
317319
318320if __name__ == '__main__' :
321+ define_flags ()
319322 app .run (main )
0 commit comments