
Commit 31ec19b

ppwwyyxx authored and facebook-github-bot committed
implement FCOS
Reviewed By: zhanghang1989
Differential Revision: D30754542
fbshipit-source-id: 7b35e4250e31e8a999b8b4b45b002c7fce0773ac
1 parent 0e29b7a commit 31ec19b

File tree: 6 files changed, +343 -3 lines changed
Lines changed: 11 additions & 0 deletions

from ..common.optim import SGD as optimizer
from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier
from ..common.data.coco import dataloader
from ..common.models.fcos import model
from ..common.train import train

dataloader.train.mapper.use_instance_mask = False
optimizer.lr = 0.01

model.backbone.bottom_up.freeze_at = 2
train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
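
Note (not part of the commit): this file is a detectron2 "lazy config", so it can be loaded and tweaked as ordinary Python/omegaconf objects before training. A minimal sketch, assuming detectron2 with this commit installed; the config path below is a placeholder, since the file's name is not shown in this excerpt:

# Sketch only -- the path is a placeholder, not the actual file added by this commit.
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("configs/path/to/the_fcos_config_above.py")  # placeholder path
cfg.optimizer.lr = 0.02                 # fields can be overridden, just like in the file
model = instantiate(cfg.model)          # recursively builds the FCOS meta-architecture
print(type(model).__name__)             # "FCOS"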

configs/common/models/fcos.py

Lines changed: 23 additions & 0 deletions

from detectron2.modeling.meta_arch.fcos import FCOS, FCOSHead

from .retinanet import model

model._target_ = FCOS

del model.anchor_generator
del model.box2box_transform
del model.anchor_matcher
del model.input_format

# Use P5 instead of C5 to compute P6/P7
# (Sec 2.2 of https://arxiv.org/abs/2006.09214)
model.backbone.top_block.in_feature = "p5"
model.backbone.top_block.in_channels = 256

# New score threshold determined based on sqrt(cls_score * centerness)
model.test_score_thresh = 0.2
model.test_nms_thresh = 0.6

model.head._target_ = FCOSHead
del model.head.num_anchors
model.head.norm = "GN"
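
Note (not part of the commit): the re-targeting above works because detectron2's instantiate() only looks at a node's `_target_` plus its remaining keys, so a derived config just swaps `_target_` and deletes keys the new class does not accept. A toy illustration of the same pattern with stand-in classes (nn.Linear / nn.Bilinear are assumptions used only for this example):

# Toy sketch of the `_target_` re-targeting pattern used in the config above.
from detectron2.config import LazyCall as L
from detectron2.config import instantiate
from torch import nn

node = L(nn.Linear)(in_features=8, out_features=4, bias=True)
node._target_ = nn.Bilinear               # re-target, like `model._target_ = FCOS`
node.in1_features = node.in_features      # Bilinear takes different arguments...
node.in2_features = 8
del node.in_features                      # ...so delete the ones it does not accept
layer = instantiate(node)                 # -> nn.Bilinear(8, 8, 4)
print(type(layer).__name__)               # "Bilinear"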

detectron2/modeling/box_regression.py

Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@
 from typing import List, Tuple
 import torch
 from fvcore.nn import giou_loss, smooth_l1_loss
+from torch.nn import functional as F

 from detectron2.layers import cat, ciou_loss, diou_loss
 from detectron2.structures import Boxes
@@ -244,6 +245,7 @@ def get_deltas(self, src_boxes, target_boxes):
         Get box regression transformation deltas (dx1, dy1, dx2, dy2) that can be used
         to transform the `src_boxes` into the `target_boxes`. That is, the relation
         ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true.
+        The center of src must be inside target boxes.

         Args:
             src_boxes (Tensor): square source boxes, e.g., anchors
@@ -277,6 +279,8 @@ def apply_deltas(self, deltas, boxes):
                 box transformations for the single box boxes[i].
             boxes (Tensor): boxes to transform, of shape (N, 4)
         """
+        # Ensure the output is a valid box. See Sec 2.1 of https://arxiv.org/abs/2006.09214
+        deltas = F.relu(deltas)
         boxes = boxes.to(deltas.dtype)

         ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
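
Note (not part of the commit): a tiny sketch of why the relu clamp matters for Box2BoxTransformLinear, which FCOS uses. The deltas are (left, top, right, bottom) distances from the box center, so a negative prediction would place an edge on the wrong side of the center and could produce x1 > x2; with the clamp, the edge stops at the center instead. The numbers below assume the code with this change applied:

# Sketch only: effect of the F.relu clamp added in apply_deltas above.
import torch
from detectron2.modeling.box_regression import Box2BoxTransformLinear

tfm = Box2BoxTransformLinear(normalize_by_size=True)
anchor = torch.tensor([[10.0, 10.0, 18.0, 18.0]])   # an 8x8 box centered at (14, 14)
deltas = torch.tensor([[-1.5, 1.0, 1.0, 1.0]])      # negative "left" distance
print(tfm.apply_deltas(deltas, anchor))
# tensor([[14., 6., 22., 22.]]) -- the left edge is clamped to the center;
# without the clamp, x1 would be 26 > x2 = 22, i.e. an invalid box.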

detectron2/modeling/meta_arch/dense_detector.py

Lines changed: 3 additions & 3 deletions

@@ -184,7 +184,7 @@ def _decode_per_level_predictions(
         score_thresh: float,
         topk_candidates: int,
         image_size: Tuple[int, int],
-    ):
+    ) -> Instances:
         """
         Decode boxes and classification predictions of one feature level, by
         the following steps:
@@ -208,7 +208,7 @@ def _decode_per_level_predictions(

         # 2. Keep top k top scoring boxes only
         num_topk = min(topk_candidates, topk_idxs.size(0))
-        # torch.sort is actually faster than .topk (at least on GPUs)
+        # torch.sort is actually faster than .topk (https://github.com/pytorch/pytorch/issues/22812)
         pred_scores, idxs = pred_scores.sort(descending=True)
         pred_scores = pred_scores[:num_topk]
         topk_idxs = topk_idxs[idxs[:num_topk]]
@@ -230,7 +230,7 @@ def _decode_multi_level_predictions(
         score_thresh: float,
         topk_candidates: int,
         image_size: Tuple[int, int],
-    ):
+    ) -> Instances:
         """
         Run `_decode_per_level_predictions` for all feature levels and concat the results.
         """

detectron2/modeling/meta_arch/fcos.py

Lines changed: 301 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.

import logging
from typing import List, Optional, Tuple
import torch
from fvcore.nn import sigmoid_focal_loss_jit
from torch import Tensor, nn
from torch.nn import functional as F

from detectron2.layers import batched_nms
from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance
from detectron2.utils.events import get_event_storage

from ..anchor_generator import DefaultAnchorGenerator
from ..backbone import Backbone
from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss
from .dense_detector import DenseDetector
from .retinanet import RetinaNetHead

__all__ = ["FCOS"]


logger = logging.getLogger(__name__)


class FCOS(DenseDetector):
    """
    Implement FCOS in :paper:`fcos`.
    """

    def __init__(
        self,
        *,
        backbone: Backbone,
        head: nn.Module,
        head_in_features: Optional[List[str]] = None,
        box2box_transform=None,
        num_classes,
        center_sampling_radius: float = 1.5,
        focal_loss_alpha=0.25,
        focal_loss_gamma=2.0,
        test_score_thresh=0.2,
        test_topk_candidates=1000,
        test_nms_thresh=0.6,
        max_detections_per_image=100,
        pixel_mean,
        pixel_std,
    ):
        """
        Args:
            center_sampling_radius: radius of the "center" of a groundtruth box,
                within which all anchor points are labeled positive.
            Other arguments mean the same as in :class:`RetinaNet`.
        """
        super().__init__(
            backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std
        )

        self.num_classes = num_classes

        # FCOS uses one anchor point per location.
        # We represent the anchor point by a box whose size equals the anchor stride.
        feature_shapes = backbone.output_shape()
        fpn_strides = [feature_shapes[k].stride for k in self.head_in_features]
        self.anchor_generator = DefaultAnchorGenerator(
            sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides
        )

        # FCOS parameterizes box regression by a linear transform,
        # where predictions are normalized by anchor stride (equal to anchor size).
        if box2box_transform is None:
            box2box_transform = Box2BoxTransformLinear(normalize_by_size=True)
        self.box2box_transform = box2box_transform

        self.center_sampling_radius = float(center_sampling_radius)

        # Loss parameters:
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma

        # Inference parameters:
        self.test_score_thresh = test_score_thresh
        self.test_topk_candidates = test_topk_candidates
        self.test_nms_thresh = test_nms_thresh
        self.max_detections_per_image = max_detections_per_image

    def forward_training(self, images, features, predictions, gt_instances):
        # Transpose the Hi*Wi*A dimension to the middle:
        pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
            predictions, [self.num_classes, 4, 1]
        )
        anchors = self.anchor_generator(features)
        gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
        return self.losses(
            anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
        )

    @torch.no_grad()
    def match_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
        """
        Match anchors with ground truth boxes.

        Args:
            anchors: #level boxes, from the highest resolution to lower resolution
            gt_instances: ground truth instances per image

        Returns:
            List[Tensor]:
                #image tensors, each is a vector of matched gt
                indices (or -1 for unmatched anchors) for all anchors.
        """
        num_anchors_per_level = [len(x) for x in anchors]
        anchors = Boxes.cat(anchors)  # Rx4
        anchor_centers = anchors.get_centers()  # Rx2
        anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0]  # R

        lower_bound = anchor_sizes * 4
        lower_bound[: num_anchors_per_level[0]] = 0
        upper_bound = anchor_sizes * 8
        upper_bound[-num_anchors_per_level[-1] :] = float("inf")

        matched_indices = []
        for gt_per_image in gt_instances:
            gt_centers = gt_per_image.gt_boxes.get_centers()  # Nx2
            # FCOS with center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_per_image.gt_boxes)

            # The original FCOS anchor matching rule: anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # Multilevel anchor matching in FCOS: each anchor is only responsible
            # for certain scale range.
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (
                pairwise_dist < upper_bound[:, None]
            )

            # Match the GT box with minimum area, if there are multiple GT matches
            gt_areas = gt_per_image.gt_boxes.area()  # N
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
            matched_idx[min_values < 1e-5] = -1  # Unmatched anchors are assigned -1

            matched_indices.append(matched_idx)
        return matched_indices

    @torch.no_grad()
    def label_anchors(self, anchors, gt_instances):
        """
        Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS
        anchor matching rule.

        Unlike RetinaNet, there are no ignored anchors.
        """
        matched_indices = self.match_anchors(anchors, gt_instances)

        matched_labels, matched_boxes = [], []
        for gt_index, gt_per_image in zip(matched_indices, gt_instances):
            label = gt_per_image.gt_classes[gt_index.clip(min=0)]
            label[gt_index < 0] = self.num_classes  # background

            matched_gt_boxes = gt_per_image.gt_boxes[gt_index.clip(min=0)]

            matched_labels.append(label)
            matched_boxes.append(matched_gt_boxes)
        return matched_labels, matched_boxes

    def losses(
        self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
    ):
        """
        This method is almost identical to :meth:`RetinaNet.losses`, with an extra
        "loss_centerness" in the returned dict.
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)

        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
        normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300)

        # classification and regression loss
        gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[
            :, :, :-1
        ]  # no loss for the last (background) class
        loss_cls = sigmoid_focal_loss_jit(
            torch.cat(pred_logits, dim=1),
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )

        loss_box_reg = _dense_box_regression_loss(
            anchors,
            self.box2box_transform,
            pred_anchor_deltas,
            [x.tensor for x in gt_boxes],
            pos_mask,
            box_reg_loss_type="giou",
        )

        ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes)  # NxR
        pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2)  # NxR
        ctrness_loss = F.binary_cross_entropy_with_logits(
            pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum"
        )
        return {
            "loss_fcos_cls": loss_cls / normalizer,
            "loss_fcos_loc": loss_box_reg / normalizer,
            "loss_fcos_ctr": ctrness_loss / normalizer,
        }

    def compute_ctrness_targets(self, anchors, gt_boxes):  # NxR
        anchors = Boxes.cat(anchors).tensor  # Rx4
        reg_targets = [self.box2box_transform.get_deltas(anchors, m.tensor) for m in gt_boxes]
        reg_targets = torch.stack(reg_targets, dim=0)  # NxRx4
        if len(reg_targets) == 0:
            return reg_targets.new_zeros(len(reg_targets))
        left_right = reg_targets[:, :, [0, 2]]
        top_bottom = reg_targets[:, :, [1, 3]]
        ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
            top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]
        )
        return torch.sqrt(ctrness)

    def forward_inference(
        self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]]
    ):
        pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
            predictions, [self.num_classes, 4, 1]
        )
        anchors = self.anchor_generator(features)

        results: List[Instances] = []
        for img_idx, image_size in enumerate(images.image_sizes):
            scores_per_image = [
                # Multiply and sqrt centerness & classification scores
                # (See eqn. 4 in https://arxiv.org/abs/2006.09214)
                torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_())
                for x, y in zip(pred_logits, pred_centerness)
            ]
            deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
            results_per_image = self.inference_single_image(
                anchors, scores_per_image, deltas_per_image, image_size
            )
            results.append(results_per_image)
        return results

    def inference_single_image(
        self,
        anchors: List[Boxes],
        box_cls: List[Tensor],
        box_delta: List[Tensor],
        image_size: Tuple[int, int],
    ):
        """
        Identical to :meth:`RetinaNet.inference_single_image`.
        """
        pred = self._decode_multi_level_predictions(
            anchors,
            box_cls,
            box_delta,
            self.test_score_thresh,
            self.test_topk_candidates,
            image_size,
        )
        keep = batched_nms(
            pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh
        )
        return pred[keep[: self.max_detections_per_image]]


class FCOSHead(RetinaNetHead):
    """
    The head used in :paper:`fcos`. It adds an additional centerness
    prediction branch on top of :class:`RetinaNetHead`.
    """

    def __init__(self, *, conv_dims: List[int], **kwargs):
        super().__init__(conv_dims=conv_dims, num_anchors=1, **kwargs)
        # Unlike original FCOS, we do not add an additional learnable scale layer
        # because it's found to have no benefits after normalizing regression targets by stride.
        self.ctrness = nn.Conv2d(conv_dims[-1], 1, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.ctrness.weight, std=0.01)
        torch.nn.init.constant_(self.ctrness.bias, 0)

    def forward(self, features):
        logits = []
        bbox_reg = []
        ctrness = []
        for feature in features:
            logits.append(self.cls_score(self.cls_subnet(feature)))
            bbox_feature = self.bbox_subnet(feature)
            bbox_reg.append(self.bbox_pred(bbox_feature))
            ctrness.append(self.ctrness(bbox_feature))
        return logits, bbox_reg, ctrness
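
Note (not part of the commit): to make the centerness target concrete, an anchor point 2px from the left edge, 8px from the right edge, and vertically centered in its matched box gets centerness sqrt((2/8) * (5/5)) = 0.5, matching compute_ctrness_targets above:

# Toy check of the centerness formula used in compute_ctrness_targets.
left, top, right, bottom = 2.0, 5.0, 8.0, 5.0   # distances from the anchor point to the 4 box sides
ctrness = ((min(left, right) / max(left, right)) * (min(top, bottom) / max(top, bottom))) ** 0.5
print(ctrness)  # 0.5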

docs/conf.py

Lines changed: 1 addition & 0 deletions

@@ -333,6 +333,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
     ),
     "dds": ("2003.13678", "Designing Network Design Spaces"),
     "scaling": ("2103.06877", "Fast and Accurate Model Scaling"),
+    "fcos": ("2006.09214", "FCOS: A Simple and Strong Anchor-free Object Detector"),
 }