Skip to content

Commit 0eb4a41

Browse files
authored
fix training example script (#1302)
1 parent 4426645 commit 0eb4a41

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

examples/training/object_detection/pascal_voc/retina_net.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,17 @@ def get_non_empty_box_indices(boxes):
189189
return indices[:, 0]
190190

191191

192-
def resize_fn(image, gt_boxes, gt_classes):
192+
def resize_fn(image, boxes, classes):
193193
image, image_info = resize_and_crop_image(
194194
image, image_size[:2], image_size[:2], 0.8, 1.25
195195
)
196-
gt_boxes = resize_and_crop_boxes(
197-
gt_boxes, image_info[2, :], image_info[1, :], image_info[3, :]
196+
boxes = resize_and_crop_boxes(
197+
boxes, image_info[2, :], image_info[1, :], image_info[3, :]
198198
)
199-
indices = get_non_empty_box_indices(gt_boxes)
200-
gt_boxes = tf.gather(gt_boxes, indices)
201-
gt_classes = tf.gather(gt_classes, indices)
202-
return image, gt_boxes, gt_classes
199+
indices = get_non_empty_box_indices(boxes)
200+
boxes = tf.gather(boxes, indices)
201+
classes = tf.gather(classes, indices)
202+
return image, boxes, classes
203203

204204

205205
def flip_fn(image, boxes):
@@ -214,21 +214,20 @@ def proc_train_fn(bounding_box_format, img_size):
214214
def apply(inputs):
215215
image = inputs["image"]
216216
image = tf.cast(image, tf.float32)
217-
gt_boxes = inputs["objects"]["bbox"]
218-
image, gt_boxes = flip_fn(image, gt_boxes)
219-
gt_boxes = keras_cv.bounding_box.convert_format(
220-
gt_boxes,
217+
boxes = inputs["objects"]["bbox"]
218+
image, boxes = flip_fn(image, boxes)
219+
boxes = keras_cv.bounding_box.convert_format(
220+
boxes,
221221
images=image,
222222
source="rel_yxyx",
223223
target="yxyx",
224224
)
225-
gt_classes = tf.cast(inputs["objects"]["label"], tf.float32)
226-
image, gt_boxes, gt_classes = resize_fn(image, gt_boxes, gt_classes)
227-
gt_classes = tf.expand_dims(gt_classes, axis=-1)
225+
classes = tf.cast(inputs["objects"]["label"], tf.float32)
226+
image, boxes, classes = resize_fn(image, boxes, classes)
228227
bounding_boxes = keras_cv.bounding_box.convert_format(
229-
gt_boxes, images=image, source="yxyx", target=bounding_box_format
228+
boxes, images=image, source="yxyx", target=bounding_box_format
230229
)
231-
bounding_boxes = tf.concat([gt_boxes, gt_classes], axis=-1)
230+
bounding_boxes = {"boxes": boxes, "classes": classes}
232231
return image, bounding_boxes
233232

234233
return apply
@@ -257,7 +256,7 @@ def apply(inputs):
257256
raw_image, (target_height, target_width), antialias=False
258257
)
259258

260-
gt_boxes = keras_cv.bounding_box.convert_format(
259+
boxes = keras_cv.bounding_box.convert_format(
261260
inputs["objects"]["bbox"],
262261
images=image,
263262
source="rel_yxyx",
@@ -266,27 +265,24 @@ def apply(inputs):
266265
image = tf.image.pad_to_bounding_box(
267266
image, 0, 0, target_size[0], target_size[1]
268267
)
269-
gt_boxes = keras_cv.bounding_box.convert_format(
270-
gt_boxes,
268+
boxes = keras_cv.bounding_box.convert_format(
269+
boxes,
271270
images=image,
272271
source="xyxy",
273272
target=bounding_box_format,
274273
)
275-
gt_classes = tf.cast(inputs["objects"]["label"], tf.float32)
276-
gt_classes = tf.expand_dims(gt_classes, axis=-1)
277-
278-
bounding_boxes = tf.concat([gt_boxes, gt_classes], axis=-1)
279-
274+
classes = tf.cast(inputs["objects"]["label"], tf.float32)
275+
bounding_boxes = {"boxes": boxes, "classes": classes}
280276
return image, bounding_boxes
281277

282278
return apply
283279

284280

285281
def pad_fn(images, boxes):
286282
boxes = boxes.to_tensor(default_value=-1.0, shape=[GLOBAL_BATCH_SIZE, 32, 5])
287-
gt_boxes = boxes[:, :, :4]
288-
gt_classes = boxes[:, :, 4]
289-
return images, {"boxes": gt_boxes, "classes": gt_classes}
283+
boxes = boxes[:, :, :4]
284+
classes = boxes[:, :, 4]
285+
return images, {"boxes": boxes, "classes": classes}
290286

291287

292288
train_ds = train_ds.map(

0 commit comments

Comments (0)