Skip to content

Commit cece47e

Browse files
authored
Remove gt_ prefix from OD models compute_loss (#1304)
1 parent 0eb4a41 commit cece47e

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

keras_cv/models/object_detection/faster_rcnn.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def compile(
437437
}
438438
super().compile(loss=losses, **kwargs)
439439

440-
def compute_loss(self, images, gt_boxes, gt_classes, training):
440+
def compute_loss(self, images, boxes, classes, training):
441441
image_shape = tf.shape(images[0])
442442
local_batch = images.get_shape().as_list()[0]
443443
if tf.distribute.has_strategy():
@@ -452,7 +452,7 @@ def compute_loss(self, images, gt_boxes, gt_classes, training):
452452
rpn_cls_targets,
453453
rpn_cls_weights,
454454
) = self.rpn_labeler(
455-
tf.concat(tf.nest.flatten(anchors), axis=0), gt_boxes, gt_classes
455+
tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes
456456
)
457457
rpn_box_weights /= self.rpn_labeler.samples_per_image * global_batch * 0.25
458458
rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch
@@ -461,7 +461,7 @@ def compute_loss(self, images, gt_boxes, gt_classes, training):
461461
)
462462
rois = tf.stop_gradient(rois)
463463
rois, box_targets, box_weights, cls_targets, cls_weights = self.roi_sampler(
464-
rois, gt_boxes, gt_classes
464+
rois, boxes, classes
465465
)
466466
box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25
467467
cls_weights /= self.roi_sampler.num_sampled_rois * global_batch
@@ -492,16 +492,16 @@ def train_step(self, data):
492492
images, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
493493
if sample_weight is not None:
494494
raise ValueError("`sample_weight` is currently not supported.")
495-
gt_boxes = y["boxes"]
495+
boxes = y["boxes"]
496496
if len(y["classes"].shape) != 2:
497497
raise ValueError(
498498
"Expected 'classes' to be a tf.Tensor of rank 2. "
499499
f"Got y['classes'].shape={y['classes'].shape}."
500500
)
501501
# TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere
502-
gt_classes = tf.expand_dims(y["classes"], axis=-1)
502+
classes = tf.expand_dims(y["classes"], axis=-1)
503503
with tf.GradientTape() as tape:
504-
total_loss = self.compute_loss(images, gt_boxes, gt_classes, training=True)
504+
total_loss = self.compute_loss(images, boxes, classes, training=True)
505505
reg_losses = []
506506
if self.weight_decay:
507507
for var in self.trainable_variables:
@@ -516,14 +516,14 @@ def test_step(self, data):
516516
images, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
517517
if sample_weight is not None:
518518
raise ValueError("`sample_weight` is currently not supported.")
519-
gt_boxes = y["boxes"]
519+
boxes = y["boxes"]
520520
if len(y["classes"].shape) != 2:
521521
raise ValueError(
522522
"Expected 'classes' to be a tf.Tensor of rank 2. "
523523
f"Got y['classes'].shape={y['classes'].shape}."
524524
)
525-
gt_classes = tf.expand_dims(y["classes"], axis=-1)
526-
self.compute_loss(images, gt_boxes, gt_classes, training=False)
525+
classes = tf.expand_dims(y["classes"], axis=-1)
526+
self.compute_loss(images, boxes, classes, training=False)
527527
return self.compute_metrics(images, {}, {}, sample_weight={})
528528

529529
def make_predict_function(self, force=False):

keras_cv/models/object_detection/retina_net/retina_net.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,12 @@ def compile(
316316
}
317317
super().compile(loss=losses, **kwargs)
318318

319-
def compute_loss(self, images, gt_boxes, gt_classes, training):
319+
def compute_loss(self, images, boxes, classes, training):
320320
box_pred, cls_pred = self._forward(images, training=training)
321-
if gt_boxes.shape[-1] != 4:
321+
if boxes.shape[-1] != 4:
322322
raise ValueError(
323-
"gt_boxes should have shape (None, None, 4). Got "
324-
f"gt_boxes.shape={tuple(gt_boxes.shape)}"
323+
"boxes should have shape (None, None, 4). Got "
324+
f"boxes.shape={tuple(boxes.shape)}"
325325
)
326326

327327
if box_pred.shape[-1] != 4:
@@ -338,18 +338,18 @@ def compute_loss(self, images, gt_boxes, gt_classes, training):
338338
)
339339

340340
cls_labels = tf.one_hot(
341-
tf.cast(gt_classes, dtype=tf.int32),
341+
tf.cast(classes, dtype=tf.int32),
342342
depth=self.classes,
343343
dtype=tf.float32,
344344
)
345345

346-
positive_mask = tf.cast(tf.greater(gt_classes, -1.0), dtype=tf.float32)
346+
positive_mask = tf.cast(tf.greater(classes, -1.0), dtype=tf.float32)
347347
normalizer = tf.reduce_sum(positive_mask)
348-
cls_weights = tf.cast(tf.math.not_equal(gt_classes, -2.0), dtype=tf.float32)
348+
cls_weights = tf.cast(tf.math.not_equal(classes, -2.0), dtype=tf.float32)
349349
cls_weights /= normalizer
350350
box_weights = positive_mask / normalizer
351351
y_true = {
352-
"box": gt_boxes,
352+
"box": boxes,
353353
"cls": cls_labels,
354354
}
355355
y_pred = {
@@ -372,16 +372,16 @@ def train_step(self, data):
372372
target=self.label_encoder.bounding_box_format,
373373
images=x,
374374
)
375-
gt_boxes, gt_classes = self.label_encoder(x, y)
376-
gt_boxes = bounding_box.convert_format(
377-
gt_boxes,
375+
boxes, classes = self.label_encoder(x, y)
376+
boxes = bounding_box.convert_format(
377+
boxes,
378378
source=self.label_encoder.bounding_box_format,
379379
target=self.bounding_box_format,
380380
images=x,
381381
)
382382

383383
with tf.GradientTape() as tape:
384-
total_loss = self.compute_loss(x, gt_boxes, gt_classes, training=True)
384+
total_loss = self.compute_loss(x, boxes, classes, training=True)
385385

386386
reg_losses = []
387387
if self.weight_decay:
@@ -405,14 +405,14 @@ def test_step(self, data):
405405
target=self.label_encoder.bounding_box_format,
406406
images=x,
407407
)
408-
gt_boxes, gt_classes = self.label_encoder(x, y)
409-
gt_boxes = bounding_box.convert_format(
410-
gt_boxes,
408+
boxes, classes = self.label_encoder(x, y)
409+
boxes = bounding_box.convert_format(
410+
boxes,
411411
source=self.label_encoder.bounding_box_format,
412412
target=self.bounding_box_format,
413413
images=x,
414414
)
415-
_ = self.compute_loss(x, gt_boxes, gt_classes, training=False)
415+
_ = self.compute_loss(x, boxes, classes, training=False)
416416

417417
return self.compute_metrics(x, {}, {}, sample_weight={})
418418

0 commit comments

Comments
 (0)