1616import tensorflow_datasets as tfds
1717
1818from keras_cv import bounding_box
19- from keras_cv import layers as cv_layers
2019
2120
22- def curry_map_function (bounding_box_format , img_size ):
21+ def curry_map_function (bounding_box_format ):
2322 """Returns a function that maps a dataset example to images and bounding boxes in `bounding_box_format`."""
2423
25- if img_size is not None :
26- resizing = cv_layers .Resizing (
27- height = img_size [0 ],
28- width = img_size [1 ],
29- bounding_box_format = bounding_box_format ,
30- pad_to_aspect_ratio = True ,
31- )
32-
3324 def apply (inputs ):
3425 images = inputs ["image" ]
3526 bounding_boxes = inputs ["objects" ]["bbox" ]
36- labels = tf .cast (inputs ["objects" ]["label" ], tf .float32 )
37- labels = tf .expand_dims (labels , axis = - 1 )
38- bounding_boxes = tf .concat ([bounding_boxes , labels ], axis = - 1 )
27+ labels = inputs ["objects" ]["label" ]
3928 bounding_boxes = bounding_box .convert_format (
4029 bounding_boxes ,
4130 images = images ,
4231 source = "rel_yxyx" ,
4332 target = bounding_box_format ,
4433 )
4534
35+ bounding_boxes = {"boxes" : bounding_boxes , "classes" : labels }
36+
4637 outputs = {"images" : images , "bounding_boxes" : bounding_boxes }
47- if img_size is not None :
48- outputs = resizing (outputs )
4938 return outputs
5039
5140 return apply
@@ -57,7 +46,7 @@ def load(
5746 batch_size = None ,
5847 shuffle_files = True ,
5948 shuffle_buffer = None ,
60- img_size = None ,
49+ dataset = "voc/2007" ,
6150):
6251 """Loads the PascalVOC dataset ('voc/2007' by default, or 'voc/2012' via `dataset`).
6352
@@ -79,20 +68,26 @@ def load(
7968 shuffle: whether or not to shuffle the dataset. Defaults to True.
8069 shuffle_buffer: the size of the buffer to use in shuffling.
8170 shuffle_files: (Optional) whether or not to shuffle files, defaults to True.
82- img_size : (Optional) size to resize the images to, if not provided image batches
83- will be of type `tf.RaggedTensor` .
71+ dataset: (Optional) the PascalVOC dataset to load from. Should be either
72+ 'voc/2007' or 'voc/2012'. Defaults to 'voc/2007'.
8473
8574 Returns:
8675 tf.data.Dataset containing PascalVOC. Each entry is a dictionary containing
8776 keys {"images": images, "bounding_boxes": bounding_boxes} where images is a
8877 Tensor of shape [batch, H, W, 3] and bounding_boxes is a dictionary
8978 {"boxes": boxes, "classes": classes} with boxes in `bounding_box_format`.
9079 """
80+ if dataset not in ["voc/2007" , "voc/2012" ]:
81+ raise ValueError (
82+ "keras_cv.datasets.pascal_voc.load() expects the `dataset` "
83+ "argument to be either 'voc/2007' or 'voc/2012', but got "
84+ f"`dataset={ dataset } `."
85+ )
9186 dataset , dataset_info = tfds .load (
92- "voc/2007" , split = split , shuffle_files = shuffle_files , with_info = True
87+ dataset , split = split , shuffle_files = shuffle_files , with_info = True
9388 )
9489 dataset = dataset .map (
95- curry_map_function (bounding_box_format = bounding_box_format , img_size = img_size ),
90+ curry_map_function (bounding_box_format = bounding_box_format ),
9691 num_parallel_calls = tf .data .AUTOTUNE ,
9792 )
9893
0 commit comments