Skip to content

Commit ff41c20

Browse files
authored
Update retinaNet + RCNN to use 'classification' (#1313)
1 parent ec39676 commit ff41c20

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

keras_cv/models/object_detection/faster_rcnn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,9 +429,9 @@ def compile(
429429
self.weight_decay = weight_decay
430430
losses = {
431431
"box": self.box_loss,
432-
"cls": self.cls_loss,
432+
"classification": self.cls_loss,
433433
"rpn_box": self.rpn_box_loss,
434-
"rpn_cls": self.rpn_cls_loss,
434+
"rpn_classification": self.rpn_cls_loss,
435435
}
436436
super().compile(loss=losses, **kwargs)
437437

@@ -466,21 +466,21 @@ def compute_loss(self, images, boxes, classes, training):
466466
box_pred, cls_pred = self._call_rcnn(rois, feature_map, training=training)
467467
y_true = {
468468
"rpn_box": rpn_box_targets,
469-
"rpn_cls": rpn_cls_targets,
469+
"rpn_classification": rpn_cls_targets,
470470
"box": box_targets,
471-
"cls": cls_targets,
471+
"classification": cls_targets,
472472
}
473473
y_pred = {
474474
"rpn_box": rpn_box_pred,
475-
"rpn_cls": rpn_cls_pred,
475+
"rpn_classification": rpn_cls_pred,
476476
"box": box_pred,
477-
"cls": cls_pred,
477+
"classification": cls_pred,
478478
}
479479
weights = {
480480
"rpn_box": rpn_box_weights,
481-
"rpn_cls": rpn_cls_weights,
481+
"rpn_classification": rpn_cls_weights,
482482
"box": box_weights,
483-
"cls": cls_weights,
483+
"classification": cls_weights,
484484
}
485485
return super().compute_loss(
486486
x=images, y=y_true, y_pred=y_pred, sample_weight=weights

keras_cv/models/object_detection/retina_net/retina_net.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def compile(
312312
self.weight_decay = weight_decay
313313
losses = {
314314
"box": self.box_loss,
315-
"cls": self.classification_loss,
315+
"classification": self.classification_loss,
316316
}
317317
super().compile(loss=losses, **kwargs)
318318

@@ -350,15 +350,15 @@ def compute_loss(self, images, boxes, classes, training):
350350
box_weights = positive_mask / normalizer
351351
y_true = {
352352
"box": boxes,
353-
"cls": cls_labels,
353+
"classification": cls_labels,
354354
}
355355
y_pred = {
356356
"box": box_pred,
357-
"cls": cls_pred,
357+
"classification": cls_pred,
358358
}
359359
sample_weights = {
360360
"box": box_weights,
361-
"cls": cls_weights,
361+
"classification": cls_weights,
362362
}
363363
return super().compute_loss(
364364
x=images, y=y_true, y_pred=y_pred, sample_weight=sample_weights

0 commit comments

Comments
 (0)