@@ -189,17 +189,17 @@ def get_non_empty_box_indices(boxes):
189189 return indices [:, 0 ]
190190
191191
def resize_fn(image, boxes, classes):
    """Resize/crop an image and keep only the boxes that survive the crop.

    Args:
        image: Image passed to `resize_and_crop_image`.
        boxes: Boxes for `image`, passed to `resize_and_crop_boxes`.
        classes: Per-box class values, filtered in lockstep with `boxes`.

    Returns:
        A tuple `(image, boxes, classes)` where `boxes` and `classes` have
        been gathered at the indices reported by
        `get_non_empty_box_indices`.
    """
    # The same module-level target size is used for both size arguments.
    target_size = image_size[:2]
    image, image_info = resize_and_crop_image(
        image, target_size, target_size, 0.8, 1.25
    )

    # NOTE(review): rows 2, 1 and 3 of `image_info` presumably hold the
    # scale, output size and offset -- confirm against the helper functions.
    boxes = resize_and_crop_boxes(
        boxes, image_info[2, :], image_info[1, :], image_info[3, :]
    )

    # Apply one index set to both tensors so boxes and classes stay aligned.
    keep = get_non_empty_box_indices(boxes)
    return image, tf.gather(boxes, keep), tf.gather(classes, keep)
203203
204204
205205def flip_fn (image , boxes ):
@@ -214,21 +214,20 @@ def proc_train_fn(bounding_box_format, img_size):
def apply(inputs):
    """Turn one raw dataset example into an `(image, bounding_boxes)` pair.

    Closure returned by `proc_train_fn`; `bounding_box_format` is read from
    the enclosing scope.

    Args:
        inputs: Mapping that provides `inputs["image"]`,
            `inputs["objects"]["bbox"]` and `inputs["objects"]["label"]`.

    Returns:
        A tuple `(image, bounding_boxes)` where `bounding_boxes` is a dict
        with keys `"boxes"` (converted to `bounding_box_format`) and
        `"classes"` (labels cast to float32).
    """
    image = tf.cast(inputs["image"], tf.float32)

    boxes = inputs["objects"]["bbox"]
    image, boxes = flip_fn(image, boxes)

    # Convert from "rel_yxyx" to "yxyx" before the resize/crop step.
    boxes = keras_cv.bounding_box.convert_format(
        boxes,
        images=image,
        source="rel_yxyx",
        target="yxyx",
    )

    classes = tf.cast(inputs["objects"]["label"], tf.float32)
    image, boxes, classes = resize_fn(image, boxes, classes)

    # Keep the converted boxes: the conversion result used to be overwritten
    # by the dict on the next statement, so "yxyx" boxes were returned no
    # matter which `bounding_box_format` was requested.
    boxes = keras_cv.bounding_box.convert_format(
        boxes, images=image, source="yxyx", target=bounding_box_format
    )
    bounding_boxes = {"boxes": boxes, "classes": classes}
    return image, bounding_boxes
233232
234233 return apply
@@ -257,7 +256,7 @@ def apply(inputs):
257256 raw_image , (target_height , target_width ), antialias = False
258257 )
259258
260- gt_boxes = keras_cv .bounding_box .convert_format (
259+ boxes = keras_cv .bounding_box .convert_format (
261260 inputs ["objects" ]["bbox" ],
262261 images = image ,
263262 source = "rel_yxyx" ,
@@ -266,27 +265,24 @@ def apply(inputs):
266265 image = tf .image .pad_to_bounding_box (
267266 image , 0 , 0 , target_size [0 ], target_size [1 ]
268267 )
269- gt_boxes = keras_cv .bounding_box .convert_format (
270- gt_boxes ,
268+ boxes = keras_cv .bounding_box .convert_format (
269+ boxes ,
271270 images = image ,
272271 source = "xyxy" ,
273272 target = bounding_box_format ,
274273 )
275- gt_classes = tf .cast (inputs ["objects" ]["label" ], tf .float32 )
276- gt_classes = tf .expand_dims (gt_classes , axis = - 1 )
277-
278- bounding_boxes = tf .concat ([gt_boxes , gt_classes ], axis = - 1 )
279-
274+ classes = tf .cast (inputs ["objects" ]["label" ], tf .float32 )
275+ bounding_boxes = {"boxes" : boxes , "classes" : classes }
280276 return image , bounding_boxes
281277
282278 return apply
283279
284280
def pad_fn(images, boxes):
    """Densify batched boxes and split them into coordinates and classes.

    Args:
        images: Batch of images; returned unchanged.
        boxes: Object exposing `to_tensor(default_value=..., shape=...)`.
            Assumes a `tf.RaggedTensor` whose last axis holds 4 box
            coordinates followed by the class -- TODO confirm.

    Returns:
        A tuple `(images, {"boxes": ..., "classes": ...})` where `"boxes"`
        is the first four channels of the dense tensor and `"classes"` is
        channel 4. Missing entries are filled with -1.0.
    """
    # NOTE(review): `proc_train_fn`'s `apply` returns a dict of boxes and
    # classes; confirm the caller still hands this function a single
    # 5-channel ragged tensor.
    dense = boxes.to_tensor(
        default_value=-1.0, shape=[GLOBAL_BATCH_SIZE, 32, 5]
    )
    # Slice both outputs from the full 5-channel tensor. Rebinding `boxes`
    # to the 4-channel slice first made channel 4 unreachable for classes.
    return images, {"boxes": dense[:, :, :4], "classes": dense[:, :, 4]}
290286
291287
292288train_ds = train_ds .map (
0 commit comments