|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import List, Optional, Tuple |
| 5 | +import torch |
| 6 | +from fvcore.nn import sigmoid_focal_loss_jit |
| 7 | +from torch import Tensor, nn |
| 8 | +from torch.nn import functional as F |
| 9 | + |
| 10 | +from detectron2.layers import batched_nms |
| 11 | +from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance |
| 12 | +from detectron2.utils.events import get_event_storage |
| 13 | + |
| 14 | +from ..anchor_generator import DefaultAnchorGenerator |
| 15 | +from ..backbone import Backbone |
| 16 | +from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss |
| 17 | +from .dense_detector import DenseDetector |
| 18 | +from .retinanet import RetinaNetHead |
| 19 | + |
| 20 | +__all__ = ["FCOS"] |
| 21 | + |
| 22 | + |
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
| 25 | + |
| 26 | +class FCOS(DenseDetector): |
| 27 | + """ |
| 28 | + Implement FCOS in :paper:`fcos`. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + *, |
| 34 | + backbone: Backbone, |
| 35 | + head: nn.Module, |
| 36 | + head_in_features: Optional[List[str]] = None, |
| 37 | + box2box_transform=None, |
| 38 | + num_classes, |
| 39 | + center_sampling_radius: float = 1.5, |
| 40 | + focal_loss_alpha=0.25, |
| 41 | + focal_loss_gamma=2.0, |
| 42 | + test_score_thresh=0.2, |
| 43 | + test_topk_candidates=1000, |
| 44 | + test_nms_thresh=0.6, |
| 45 | + max_detections_per_image=100, |
| 46 | + pixel_mean, |
| 47 | + pixel_std, |
| 48 | + ): |
| 49 | + """ |
| 50 | + Args: |
| 51 | + center_sampling_radius: radius of the "center" of a groundtruth box, |
| 52 | + within which all anchor points are labeled positive. |
| 53 | + Other arguments mean the same as in :class:`RetinaNet`. |
| 54 | + """ |
| 55 | + super().__init__( |
| 56 | + backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std |
| 57 | + ) |
| 58 | + |
| 59 | + self.num_classes = num_classes |
| 60 | + |
| 61 | + # FCOS uses one anchor point per location. |
| 62 | + # We represent the anchor point by a box whose size equals the anchor stride. |
| 63 | + feature_shapes = backbone.output_shape() |
| 64 | + fpn_strides = [feature_shapes[k].stride for k in self.head_in_features] |
| 65 | + self.anchor_generator = DefaultAnchorGenerator( |
| 66 | + sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides |
| 67 | + ) |
| 68 | + |
| 69 | + # FCOS parameterizes box regression by a linear transform, |
| 70 | + # where predictions are normalized by anchor stride (equal to anchor size). |
| 71 | + if box2box_transform is None: |
| 72 | + box2box_transform = Box2BoxTransformLinear(normalize_by_size=True) |
| 73 | + self.box2box_transform = box2box_transform |
| 74 | + |
| 75 | + self.center_sampling_radius = float(center_sampling_radius) |
| 76 | + |
| 77 | + # Loss parameters: |
| 78 | + self.focal_loss_alpha = focal_loss_alpha |
| 79 | + self.focal_loss_gamma = focal_loss_gamma |
| 80 | + |
| 81 | + # Inference parameters: |
| 82 | + self.test_score_thresh = test_score_thresh |
| 83 | + self.test_topk_candidates = test_topk_candidates |
| 84 | + self.test_nms_thresh = test_nms_thresh |
| 85 | + self.max_detections_per_image = max_detections_per_image |
| 86 | + |
| 87 | + def forward_training(self, images, features, predictions, gt_instances): |
| 88 | + # Transpose the Hi*Wi*A dimension to the middle: |
| 89 | + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( |
| 90 | + predictions, [self.num_classes, 4, 1] |
| 91 | + ) |
| 92 | + anchors = self.anchor_generator(features) |
| 93 | + gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) |
| 94 | + return self.losses( |
| 95 | + anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness |
| 96 | + ) |
| 97 | + |
| 98 | + @torch.no_grad() |
| 99 | + def match_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]): |
| 100 | + """ |
| 101 | + Match anchors with ground truth boxes. |
| 102 | +
|
| 103 | + Args: |
| 104 | + anchors: #level boxes, from the highest resolution to lower resolution |
| 105 | + gt_instances: ground truth instances per image |
| 106 | +
|
| 107 | + Returns: |
| 108 | + List[Tensor]: |
| 109 | + #image tensors, each is a vector of matched gt |
| 110 | + indices (or -1 for unmatched anchors) for all anchors. |
| 111 | + """ |
| 112 | + num_anchors_per_level = [len(x) for x in anchors] |
| 113 | + anchors = Boxes.cat(anchors) # Rx4 |
| 114 | + anchor_centers = anchors.get_centers() # Rx2 |
| 115 | + anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # R |
| 116 | + |
| 117 | + lower_bound = anchor_sizes * 4 |
| 118 | + lower_bound[: num_anchors_per_level[0]] = 0 |
| 119 | + upper_bound = anchor_sizes * 8 |
| 120 | + upper_bound[-num_anchors_per_level[-1] :] = float("inf") |
| 121 | + |
| 122 | + matched_indices = [] |
| 123 | + for gt_per_image in gt_instances: |
| 124 | + gt_centers = gt_per_image.gt_boxes.get_centers() # Nx2 |
| 125 | + # FCOS with center sampling: anchor point must be close enough to gt center. |
| 126 | + pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( |
| 127 | + dim=2 |
| 128 | + ).values < self.center_sampling_radius * anchor_sizes[:, None] |
| 129 | + pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_per_image.gt_boxes) |
| 130 | + |
| 131 | + # The original FCOS anchor matching rule: anchor point must be inside gt |
| 132 | + pairwise_match &= pairwise_dist.min(dim=2).values > 0 |
| 133 | + |
| 134 | + # Multilevel anchor matching in FCOS: each anchor is only responsible |
| 135 | + # for certain scale range. |
| 136 | + pairwise_dist = pairwise_dist.max(dim=2).values |
| 137 | + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & ( |
| 138 | + pairwise_dist < upper_bound[:, None] |
| 139 | + ) |
| 140 | + |
| 141 | + # Match the GT box with minimum area, if there are multiple GT matches |
| 142 | + gt_areas = gt_per_image.gt_boxes.area() # N |
| 143 | + pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) |
| 144 | + min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match |
| 145 | + matched_idx[min_values < 1e-5] = -1 # Unmatched anchors are assigned -1 |
| 146 | + |
| 147 | + matched_indices.append(matched_idx) |
| 148 | + return matched_indices |
| 149 | + |
| 150 | + @torch.no_grad() |
| 151 | + def label_anchors(self, anchors, gt_instances): |
| 152 | + """ |
| 153 | + Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS |
| 154 | + anchor matching rule. |
| 155 | +
|
| 156 | + Unlike RetinaNet, there are no ignored anchors. |
| 157 | + """ |
| 158 | + matched_indices = self.match_anchors(anchors, gt_instances) |
| 159 | + |
| 160 | + matched_labels, matched_boxes = [], [] |
| 161 | + for gt_index, gt_per_image in zip(matched_indices, gt_instances): |
| 162 | + label = gt_per_image.gt_classes[gt_index.clip(min=0)] |
| 163 | + label[gt_index < 0] = self.num_classes # background |
| 164 | + |
| 165 | + matched_gt_boxes = gt_per_image.gt_boxes[gt_index.clip(min=0)] |
| 166 | + |
| 167 | + matched_labels.append(label) |
| 168 | + matched_boxes.append(matched_gt_boxes) |
| 169 | + return matched_labels, matched_boxes |
| 170 | + |
| 171 | + def losses( |
| 172 | + self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness |
| 173 | + ): |
| 174 | + """ |
| 175 | + This method is almost identical to :meth:`RetinaNet.losses`, with an extra |
| 176 | + "loss_centerness" in the returned dict. |
| 177 | + """ |
| 178 | + num_images = len(gt_labels) |
| 179 | + gt_labels = torch.stack(gt_labels) # (N, R) |
| 180 | + |
| 181 | + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) |
| 182 | + num_pos_anchors = pos_mask.sum().item() |
| 183 | + get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images) |
| 184 | + normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300) |
| 185 | + |
| 186 | + # classification and regression loss |
| 187 | + gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[ |
| 188 | + :, :, :-1 |
| 189 | + ] # no loss for the last (background) class |
| 190 | + loss_cls = sigmoid_focal_loss_jit( |
| 191 | + torch.cat(pred_logits, dim=1), |
| 192 | + gt_labels_target.to(pred_logits[0].dtype), |
| 193 | + alpha=self.focal_loss_alpha, |
| 194 | + gamma=self.focal_loss_gamma, |
| 195 | + reduction="sum", |
| 196 | + ) |
| 197 | + |
| 198 | + loss_box_reg = _dense_box_regression_loss( |
| 199 | + anchors, |
| 200 | + self.box2box_transform, |
| 201 | + pred_anchor_deltas, |
| 202 | + [x.tensor for x in gt_boxes], |
| 203 | + pos_mask, |
| 204 | + box_reg_loss_type="giou", |
| 205 | + ) |
| 206 | + |
| 207 | + ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes) # NxR |
| 208 | + pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2) # NxR |
| 209 | + ctrness_loss = F.binary_cross_entropy_with_logits( |
| 210 | + pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum" |
| 211 | + ) |
| 212 | + return { |
| 213 | + "loss_fcos_cls": loss_cls / normalizer, |
| 214 | + "loss_fcos_loc": loss_box_reg / normalizer, |
| 215 | + "loss_fcos_ctr": ctrness_loss / normalizer, |
| 216 | + } |
| 217 | + |
| 218 | + def compute_ctrness_targets(self, anchors, gt_boxes): # NxR |
| 219 | + anchors = Boxes.cat(anchors).tensor # Rx4 |
| 220 | + reg_targets = [self.box2box_transform.get_deltas(anchors, m.tensor) for m in gt_boxes] |
| 221 | + reg_targets = torch.stack(reg_targets, dim=0) # NxRx4 |
| 222 | + if len(reg_targets) == 0: |
| 223 | + return reg_targets.new_zeros(len(reg_targets)) |
| 224 | + left_right = reg_targets[:, :, [0, 2]] |
| 225 | + top_bottom = reg_targets[:, :, [1, 3]] |
| 226 | + ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( |
| 227 | + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0] |
| 228 | + ) |
| 229 | + return torch.sqrt(ctrness) |
| 230 | + |
| 231 | + def forward_inference( |
| 232 | + self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]] |
| 233 | + ): |
| 234 | + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( |
| 235 | + predictions, [self.num_classes, 4, 1] |
| 236 | + ) |
| 237 | + anchors = self.anchor_generator(features) |
| 238 | + |
| 239 | + results: List[Instances] = [] |
| 240 | + for img_idx, image_size in enumerate(images.image_sizes): |
| 241 | + scores_per_image = [ |
| 242 | + # Multiply and sqrt centerness & classification scores |
| 243 | + # (See eqn. 4 in https://arxiv.org/abs/2006.09214) |
| 244 | + torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_()) |
| 245 | + for x, y in zip(pred_logits, pred_centerness) |
| 246 | + ] |
| 247 | + deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] |
| 248 | + results_per_image = self.inference_single_image( |
| 249 | + anchors, scores_per_image, deltas_per_image, image_size |
| 250 | + ) |
| 251 | + results.append(results_per_image) |
| 252 | + return results |
| 253 | + |
| 254 | + def inference_single_image( |
| 255 | + self, |
| 256 | + anchors: List[Boxes], |
| 257 | + box_cls: List[Tensor], |
| 258 | + box_delta: List[Tensor], |
| 259 | + image_size: Tuple[int, int], |
| 260 | + ): |
| 261 | + """ |
| 262 | + Identical to :meth:`RetinaNet.inference_single_image. |
| 263 | + """ |
| 264 | + pred = self._decode_multi_level_predictions( |
| 265 | + anchors, |
| 266 | + box_cls, |
| 267 | + box_delta, |
| 268 | + self.test_score_thresh, |
| 269 | + self.test_topk_candidates, |
| 270 | + image_size, |
| 271 | + ) |
| 272 | + keep = batched_nms( |
| 273 | + pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh |
| 274 | + ) |
| 275 | + return pred[keep[: self.max_detections_per_image]] |
| 276 | + |
| 277 | + |
| 278 | +class FCOSHead(RetinaNetHead): |
| 279 | + """ |
| 280 | + The head used in :paper:`fcos`. It adds an additional centerness |
| 281 | + prediction branch on top of :class:`RetinaNetHead`. |
| 282 | + """ |
| 283 | + |
| 284 | + def __init__(self, *, conv_dims: List[int], **kwargs): |
| 285 | + super().__init__(conv_dims=conv_dims, num_anchors=1, **kwargs) |
| 286 | + # Unlike original FCOS, we do not add an additional learnable scale layer |
| 287 | + # because it's found to have no benefits after normalizing regression targets by stride. |
| 288 | + self.ctrness = nn.Conv2d(conv_dims[-1], 1, kernel_size=3, stride=1, padding=1) |
| 289 | + torch.nn.init.normal_(self.ctrness.weight, std=0.01) |
| 290 | + torch.nn.init.constant_(self.ctrness.bias, 0) |
| 291 | + |
| 292 | + def forward(self, features): |
| 293 | + logits = [] |
| 294 | + bbox_reg = [] |
| 295 | + ctrness = [] |
| 296 | + for feature in features: |
| 297 | + logits.append(self.cls_score(self.cls_subnet(feature))) |
| 298 | + bbox_feature = self.bbox_subnet(feature) |
| 299 | + bbox_reg.append(self.bbox_pred(bbox_feature)) |
| 300 | + ctrness.append(self.ctrness(bbox_feature)) |
| 301 | + return logits, bbox_reg, ctrness |
0 commit comments