Commit efd8f4d

ppwwyyxx authored and facebook-github-bot committed

generalized dense detector base

Reviewed By: zhanghang1989, wat3rBro

Differential Revision: D30721780

fbshipit-source-id: eeefca099c4eb158d59f697f4e4a1419dc69a806
1 parent 0e76e35 commit efd8f4d

File tree

7 files changed: +355 -249 lines changed


detectron2/export/caffe2_modeling.py

Lines changed: 21 additions & 11 deletions
@@ -8,7 +8,6 @@
 
 from detectron2.modeling import meta_arch
 from detectron2.modeling.box_regression import Box2BoxTransform
-from detectron2.modeling.meta_arch.retinanet import permute_to_N_HWA_K
 from detectron2.modeling.roi_heads import keypoint_head
 from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
 

@@ -369,14 +368,28 @@ def get_outputs_converter(predict_net, init_net):
         )
 
         # hack to reuse inference code from RetinaNet
-        self.inference = functools.partial(meta_arch.RetinaNet.inference, self)
-        self.inference_single_image = functools.partial(
-            meta_arch.RetinaNet.inference_single_image, self
-        )
+        for meth in [
+            "forward_inference",
+            "inference_single_image",
+            "_transpose_dense_predictions",
+            "_decode_multi_level_predictions",
+            "_decode_per_level_predictions",
+        ]:
+            setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))
 
         def f(batched_inputs, c2_inputs, c2_results):
             _, im_info = c2_inputs
             image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
+            dummy_images = ImageList(
+                torch.randn(
+                    (
+                        len(im_info),
+                        3,
+                    )
+                    + tuple(image_sizes[0])
+                ),
+                image_sizes,
+            )
 
             num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
             pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
@@ -385,15 +398,12 @@ def f(batched_inputs, c2_inputs, c2_results):
             # For each feature level, feature should have the same batch size and
             # spatial dimension as the box_cls and box_delta.
             dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
-            anchors = self.anchor_generator(dummy_features)
-
             # self.num_classess can be inferred
             self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)
 
-            pred_logits = [permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits]
-            pred_anchor_deltas = [permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas]
-
-            results = self.inference(anchors, pred_logits, pred_anchor_deltas, image_sizes)
+            results = self.forward_inference(
+                dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
+            )
             return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
 
         return f
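The key change in this hunk is how the converter borrows RetinaNet's inference code: instead of binding two hard-coded methods, it loops over method names and attaches each one to the converter object with functools.partial(getattr(meta_arch.RetinaNet, meth), self). Below is a minimal, self-contained sketch of that binding pattern; the Greeter class and fake_self object are illustrative stand-ins, not part of detectron2.

import functools
import types


class Greeter:
    # Stand-in for meta_arch.RetinaNet: the methods only rely on attributes of
    # `self`, so they work on any object that provides those attributes.
    def greet(self):
        return f"hello, {self.name}"

    def shout(self):
        return self.greet().upper()


# A bare namespace object plays the role of the exported model's `self`.
fake_self = types.SimpleNamespace(name="caffe2 model")

# Bind the plain functions from the class to `fake_self`, mirroring the loop in the diff.
for meth in ["greet", "shout"]:
    setattr(fake_self, meth, functools.partial(getattr(Greeter, meth), fake_self))

print(fake_self.greet())  # hello, caffe2 model
print(fake_self.shout())  # HELLO, CAFFE2 MODEL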

detectron2/modeling/matcher.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from detectron2.layers import nonzero_tuple
 
 
+# TODO: the name is too general
 class Matcher(object):
     """
     This class assigns to each predicted "element" (e.g., a box) a ground-truth

detectron2/modeling/meta_arch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 
 # import all the meta_arch, so they will be registered
 from .rcnn import GeneralizedRCNN, ProposalNetwork
+from .dense_detector import DenseDetector
 from .retinanet import RetinaNet
 from .semantic_seg import SEM_SEG_HEADS_REGISTRY, SemanticSegmentor, build_sem_seg_head
 

detectron2/modeling/meta_arch/dense_detector.py

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor, nn

from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.modeling import Backbone
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.utils.events import get_event_storage

from ..postprocessing import detector_postprocess


def permute_to_N_HWA_K(tensor, K: int):
    """
    Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K)
    """
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)
    tensor = tensor.reshape(N, -1, K)  # Size=(N,HWA,K)
    return tensor


class DenseDetector(nn.Module):
    """
    Base class for dense detector. We define a dense detector as a fully-convolutional model that
    makes per-pixel (i.e. dense) predictions.
    """

    def __init__(
        self,
        backbone: Backbone,
        head: nn.Module,
        head_in_features: Optional[List[str]] = None,
        *,
        pixel_mean,
        pixel_std,
    ):
        """
        Args:
            backbone: backbone module
            head: head module
            head_in_features: backbone features to use in head. Default to all backbone features.
            pixel_mean (Tuple[float]):
                Values to be used for image normalization (BGR order).
                To train on images of different number of channels, set different mean & std.
                Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
            pixel_std (Tuple[float]):
                When using pre-trained models in Detectron1 or any MSRA models,
                std has been absorbed into its conv1 weights, so the std needs to be set 1.
                Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
        """
        super().__init__()

        self.backbone = backbone
        self.head = head
        if head_in_features is None:
            shapes = self.backbone.output_shape()
            self.head_in_features = sorted(shapes.keys(), key=lambda x: shapes[x].stride)
        else:
            self.head_in_features = head_in_features

        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs: List[Dict[str, Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances: Instances

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            In training, dict[str, Tensor]: mapping from a named loss to a tensor storing the
            loss. Used during training only. In inference, the standard output format, described
            in :doc:`/tutorials/models`.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        features = [features[f] for f in self.head_in_features]
        predictions = self.head(features)

        if self.training:
            assert not torch.jit.is_scripting(), "Not supported"
            assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            return self.forward_training(images, features, predictions, gt_instances)
        else:
            results = self.forward_inference(images, features, predictions)
            if torch.jit.is_scripting():
                return results

            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results

    def forward_training(self, images, features, predictions, gt_instances):
        raise NotImplementedError()

    def preprocess_image(self, batched_inputs: List[Dict[str, Tensor]]):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    def _transpose_dense_predictions(
        self, predictions: List[List[Tensor]], dims_per_anchor: List[int]
    ) -> List[List[Tensor]]:
        """
        Transpose the dense per-level predictions.

        Args:
            predictions: a list of outputs, each is a list of per-level
                predictions with shape (N, Ai x K, Hi, Wi), where N is the
                number of images, Ai is the number of anchors per location on
                level i, K is the dimension of predictions per anchor.
            dims_per_anchor: the value of K for each predictions. e.g. 4 for
                box prediction, #classes for classification prediction.

        Returns:
            List[List[Tensor]]: each prediction is transposed to (N, Hi x Wi x Ai, K).
        """
        assert len(predictions) == len(dims_per_anchor)
        res: List[List[Tensor]] = []
        for pred, dim_per_anchor in zip(predictions, dims_per_anchor):
            pred = [permute_to_N_HWA_K(x, dim_per_anchor) for x in pred]
            res.append(pred)
        return res

    def _ema_update(self, name: str, value: float, initial_value: float, momentum: float = 0.9):
        """
        Apply EMA update to `self.name` using `value`.

        This is mainly used for loss normalizer. In Detectron1, loss is normalized by number
        of foreground samples in the batch. When batch size is 1 per GPU, #foreground has a
        large variance and using it leads to lower performance. Therefore we maintain an EMA of
        #foreground to stabilize the normalizer.

        Args:
            name: name of the normalizer
            value: the new value to update
            initial_value: the initial value to start with
            momentum: momentum of EMA

        Returns:
            float: the updated EMA value
        """
        if hasattr(self, name):
            old = getattr(self, name)
        else:
            old = initial_value
        new = old * momentum + value * (1 - momentum)
        setattr(self, name, new)
        return new

    def _decode_per_level_predictions(
        self,
        anchors: Boxes,
        pred_scores: Tensor,
        pred_deltas: Tensor,
        score_thresh: float,
        topk_candidates: int,
        image_size: Tuple[int, int],
    ):
        """
        Decode boxes and classification predictions of one feature level, by
        the following steps:
        1. filter the predictions based on score threshold and top K scores.
        2. transform the box regression outputs
        3. return the predicted scores, classes and boxes

        Args:
            anchors: Boxes, anchor for this feature level
            pred_scores: HxWxA,K
            pred_deltas: HxWxA,4

        Returns:
            Instances: with field "scores", "pred_boxes", "pred_classes".
        """
        # Apply two filtering to make NMS faster.
        # 1. Keep boxes with confidence score higher than threshold
        keep_idxs = pred_scores > score_thresh
        pred_scores = pred_scores[keep_idxs]
        topk_idxs = torch.nonzero(keep_idxs)  # Kx2

        # 2. Keep top k top scoring boxes only
        num_topk = min(topk_candidates, topk_idxs.size(0))
        # torch.sort is actually faster than .topk (at least on GPUs)
        pred_scores, idxs = pred_scores.sort(descending=True)
        pred_scores = pred_scores[:num_topk]
        topk_idxs = topk_idxs[idxs[:num_topk]]

        anchor_idxs, classes_idxs = topk_idxs.unbind(dim=1)

        pred_boxes = self.box2box_transform.apply_deltas(
            pred_deltas[anchor_idxs], anchors.tensor[anchor_idxs]
        )
        return Instances(
            image_size, pred_boxes=Boxes(pred_boxes), scores=pred_scores, pred_classes=classes_idxs
        )

    def _decode_multi_level_predictions(
        self,
        anchors: List[Boxes],
        pred_scores: List[Tensor],
        pred_deltas: List[Tensor],
        score_thresh: float,
        topk_candidates: int,
        image_size: Tuple[int, int],
    ):
        """
        Run `_decode_per_level_predictions` for all feature levels and concat the results.
        """
        predictions = [
            self._decode_per_level_predictions(
                anchors_i,
                box_cls_i,
                box_reg_i,
                self.test_score_thresh,
                self.test_topk_candidates,
                image_size,
            )
            # Iterate over every feature level
            for box_cls_i, box_reg_i, anchors_i in zip(pred_scores, pred_deltas, anchors)
        ]
        return predictions[0].cat(predictions)  # 'Instances.cat' is not scriptable but this is

    def visualize_training(self, batched_inputs, results):
        """
        A function used to visualize ground truth images and final network predictions.
        It shows ground truth bounding boxes on the original image and up to 20
        predicted object bounding boxes on the original image.

        Args:
            batched_inputs (list): a list that contains input to the model.
            results (List[Instances]): a list of #images elements returned by forward_inference().
        """
        from detectron2.utils.visualizer import Visualizer

        assert len(batched_inputs) == len(
            results
        ), "Cannot visualize inputs and results of different sizes"
        storage = get_event_storage()
        max_boxes = 20

        image_index = 0  # only visualize a single image
        img = batched_inputs[image_index]["image"]
        img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
        v_gt = Visualizer(img, None)
        v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
        anno_img = v_gt.get_image()
        processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
        predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()

        v_pred = Visualizer(img, None)
        v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
        prop_img = v_pred.get_image()
        vis_img = np.vstack((anno_img, prop_img))
        vis_img = vis_img.transpose(2, 0, 1)
        vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
        storage.put_image(vis_name, vis_img)
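To make the tensor layout used by permute_to_N_HWA_K and _transpose_dense_predictions concrete, here is a small shape check. It redefines the helper locally so it runs without detectron2 installed; the sizes are arbitrary example values, not anything prescribed by this commit.

import torch


def permute_to_N_HWA_K(tensor, K: int):
    # Same reshape as above: (N, Ai*K, Hi, Wi) -> (N, Hi*Wi*Ai, K)
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)
    return tensor.reshape(N, -1, K)


# Example sizes: batch of 2, 9 anchors per location, 80 classes, a 25x38 feature map.
N, A, K, H, W = 2, 9, 80, 25, 38
box_cls = torch.randn(N, A * K, H, W)    # raw classification output of one level
box_delta = torch.randn(N, A * 4, H, W)  # raw box regression output of one level

flat_cls = permute_to_N_HWA_K(box_cls, K)
flat_delta = permute_to_N_HWA_K(box_delta, 4)
print(flat_cls.shape)    # torch.Size([2, 8550, 80]) == (N, H*W*A, K)
print(flat_delta.shape)  # torch.Size([2, 8550, 4])

This flattened layout is what _decode_per_level_predictions expects per image ("pred_scores: HxWxA,K" and "pred_deltas: HxWxA,4" in its docstring).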
