Skip to content

Commit 65aa210

Browse files
author
wangjianfeng
committed
feat(detection): support Objects365 and reformat
1 parent 9766a39 commit 65aa210

File tree

10 files changed

+175
-83
lines changed

10 files changed

+175
-83
lines changed

official/vision/detection/layers/basic/functional.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
import megengine.functional as F
1111
import numpy as np
1212

13-
from megengine import _internal as mgb
14-
from megengine.core import Tensor, wrap_io_tensor
13+
from megengine.core import Tensor
1514

1615

1716
def get_padded_array_np(
@@ -86,8 +85,3 @@ def get_padded_tensor(
8685
else:
8786
raise Exception("Not supported tensor dim: %d" % ndim)
8887
return padded_array
89-
90-
91-
@wrap_io_tensor
92-
def indexing_set_one_hot(inp, axis, idx, value) -> Tensor:
93-
return mgb.opr.indexing_set_one_hot(inp, axis, idx, value)

official/vision/detection/layers/det/loss.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313
from megengine.core import tensor, Tensor
1414

15-
from official.vision.detection.layers import basic
16-
1715

1816
def get_focal_loss(
1917
score: Tensor,
@@ -51,28 +49,19 @@ def get_focal_loss(
5149
Returns:
5250
the calculated focal loss.
5351
"""
54-
mask = 1 - (label == ignore_label)
55-
valid_label = label * mask
56-
57-
score_shp = score.shape
58-
zero_mat = mge.zeros(
59-
F.concat([score_shp[0], score_shp[1], score_shp[2] + 1], axis=0),
60-
dtype=np.float32,
61-
)
62-
one_mat = mge.ones(
63-
F.concat([score_shp[0], score_shp[1], tensor(1)], axis=0), dtype=np.float32,
64-
)
65-
66-
one_hot = basic.indexing_set_one_hot(
67-
zero_mat, 2, valid_label.astype(np.int32), one_mat
68-
)[:, :, 1:]
69-
pos_part = F.power(1 - score, gamma) * one_hot * F.log(score)
70-
neg_part = F.power(score, gamma) * (1 - one_hot) * F.log(1 - score)
71-
loss = -(alpha * pos_part + (1 - alpha) * neg_part).sum(axis=2) * mask
52+
class_range = F.arange(1, score.shape[2] + 1)
53+
54+
label = F.add_axis(label, axis=2)
55+
pos_part = (1 - score) ** gamma * F.log(score)
56+
neg_part = score ** gamma * F.log(1 - score)
57+
58+
pos_loss = -(label == class_range) * pos_part * alpha
59+
neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
60+
loss = pos_loss + neg_loss
7261

7362
if norm_type == "fg":
74-
positive_mask = label > background
75-
return loss.sum() / F.maximum(positive_mask.sum(), 1)
63+
fg_mask = (label != background) * (label != ignore_label)
64+
return loss.sum() / F.maximum(fg_mask.sum(), 1)
7665
elif norm_type == "none":
7766
return loss.sum()
7867
else:
@@ -117,8 +106,7 @@ def get_smooth_l1_loss(
117106
gt_bbox = gt_bbox.reshape(-1, 4)
118107
label = label.reshape(-1)
119108

120-
valid_mask = 1 - (label == ignore_label)
121-
fg_mask = (1 - (label == background)) * valid_mask
109+
fg_mask = (label != background) * (label != ignore_label)
122110

123111
losses = get_smooth_l1_base(pred_bbox, gt_bbox, sigma, is_fix=fix_smooth_l1)
124112
if norm_type == "fg":
@@ -154,19 +142,16 @@ def get_smooth_l1_base(
154142
cond_point = sigma
155143
x = pred_bbox - gt_bbox
156144
abs_x = F.abs(x)
157-
in_mask = abs_x < cond_point
158-
out_mask = 1 - in_mask
159-
in_loss = 0.5 * (x ** 2)
160-
out_loss = sigma * abs_x - 0.5 * (sigma ** 2)
161-
loss = in_loss * in_mask + out_loss * out_mask
145+
in_loss = 0.5 * x ** 2
146+
out_loss = sigma * abs_x - 0.5 * sigma ** 2
162147
else:
163148
sigma2 = sigma ** 2
164149
cond_point = 1 / sigma2
165150
x = pred_bbox - gt_bbox
166151
abs_x = F.abs(x)
167-
in_mask = abs_x < cond_point
168-
out_mask = 1 - in_mask
169-
in_loss = 0.5 * (sigma * x) ** 2
152+
in_loss = 0.5 * x ** 2 * sigma2
170153
out_loss = abs_x - 0.5 / sigma2
171-
loss = in_loss * in_mask + out_loss * out_mask
154+
in_mask = abs_x < cond_point
155+
out_mask = 1 - in_mask
156+
loss = in_loss * in_mask + out_loss * out_mask
172157
return loss

official/vision/detection/layers/det/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, cfg, input_shape: List[basic.ShapeSpec]):
2828
num_classes = cfg.num_classes
2929
num_convs = 4
3030
prior_prob = cfg.cls_prior_prob
31-
num_anchors = [9, 9, 9, 9, 9]
31+
num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5
3232

3333
assert (
3434
len(set(num_anchors)) == 1
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# -*- coding: utf-8 -*-
2+
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3+
#
4+
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5+
#
6+
# Unless required by applicable law or agreed to in writing,
7+
# software distributed under the License is distributed on an
8+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
from .retinanet import *
10+
11+
_EXCLUDE = {}
12+
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]

official/vision/detection/retinanet_res50_1x_800size.py renamed to official/vision/detection/models/retinanet.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import megengine.functional as F
1111
import megengine.module as M
1212
import numpy as np
13-
from megengine import hub
1413

1514
from official.vision.classification.resnet.model import resnet50
1615
from official.vision.detection import layers
@@ -47,7 +46,7 @@ def __init__(self, cfg, batch_size):
4746
for p in bottom_up.layer1.parameters():
4847
p.requires_grad = False
4948

50-
# -------------------------- build the FPN -------------------------- #
49+
# ----------------------- build the FPN ----------------------------- #
5150
in_channels_p6p7 = 2048
5251
out_channels = 256
5352
self.backbone = layers.FPN(
@@ -61,7 +60,7 @@ def __init__(self, cfg, batch_size):
6160
backbone_shape = self.backbone.output_shape()
6261
feature_shapes = [backbone_shape[f] for f in self.in_features]
6362

64-
# -------------------------- build the RetinaNet Head -------------- #
63+
# ----------------------- build the RetinaNet Head ------------------ #
6564
self.head = layers.RetinaNetHead(cfg, feature_shapes)
6665

6766
self.inputs = {
@@ -199,13 +198,22 @@ def __init__(self):
199198
self.resnet_norm = "FrozenBN"
200199
self.backbone_freeze_at = 2
201200

202-
# ------------------------ data cfg --------------------------- #
201+
# ------------------------ data cfg -------------------------- #
202+
self.train_dataset = dict(
203+
name="coco",
204+
root="train2017",
205+
ann_file="instances_train2017.json"
206+
)
207+
self.test_dataset = dict(
208+
name="coco",
209+
root="val2017",
210+
ann_file="instances_val2017.json"
211+
)
203212
self.train_image_short_size = 800
204213
self.train_image_max_size = 1333
205214
self.num_classes = 80
206215
self.img_mean = np.array([103.530, 116.280, 123.675]) # BGR
207216
self.img_std = np.array([57.375, 57.120, 58.395])
208-
# self.img_std = np.array([1.0, 1.0, 1.0])
209217
self.reg_mean = None
210218
self.reg_std = np.array([0.1, 0.1, 0.2, 0.2])
211219

@@ -217,7 +225,7 @@ def __init__(self):
217225
self.class_aware_box = False
218226
self.cls_prior_prob = 0.01
219227

220-
# ------------------------ losss cfg ------------------------- #
228+
# ------------------------ loss cfg -------------------------- #
221229
self.focal_loss_alpha = 0.25
222230
self.focal_loss_gamma = 2
223231
self.reg_loss_weight = 1.0 / 4.0
@@ -229,29 +237,14 @@ def __init__(self):
229237
self.log_interval = 20
230238
self.nr_images_epoch = 80000
231239
self.max_epoch = 18
232-
self.warm_iters = 100
240+
self.warm_iters = 500
233241
self.lr_decay_rate = 0.1
234242
self.lr_decay_sates = [12, 16, 17]
235243

236-
# ------------------------ testing cfg ------------------------- #
244+
# ------------------------ testing cfg ----------------------- #
237245
self.test_image_short_size = 800
238246
self.test_image_max_size = 1333
239247
self.test_max_boxes_per_image = 100
240248
self.test_vis_threshold = 0.3
241249
self.test_cls_threshold = 0.05
242250
self.test_nms = 0.5
243-
244-
245-
@hub.pretrained(
246-
"https://data.megengine.org.cn/models/weights/"
247-
"retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
248-
)
249-
def retinanet_res50_1x_800size(batch_size=1, **kwargs):
250-
r"""ResNet-18 model from
251-
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
252-
"""
253-
return RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
254-
255-
256-
Net = RetinaNet
257-
Cfg = RetinaNetConfig
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3+
#
4+
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5+
#
6+
# Unless required by applicable law or agreed to in writing,
7+
# software distributed under the License is distributed on an
8+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
from megengine import hub
10+
11+
from official.vision.detection import models
12+
13+
14+
class CustomRetinaNetConfig(models.RetinaNetConfig):
15+
def __init__(self):
16+
super().__init__()
17+
18+
# ------------------------ data cfg -------------------------- #
19+
self.train_dataset = dict(
20+
name="coco",
21+
root="train2017",
22+
ann_file="annotations/instances_train2017.json"
23+
)
24+
self.test_dataset = dict(
25+
name="coco",
26+
root="val2017",
27+
ann_file="annotations/instances_val2017.json"
28+
)
29+
30+
31+
@hub.pretrained(
32+
"https://data.megengine.org.cn/models/weights/"
33+
"retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
34+
)
35+
def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs):
36+
r"""ResNet-50 model from
37+
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
38+
"""
39+
return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
40+
41+
42+
Net = models.RetinaNet
43+
Cfg = CustomRetinaNetConfig
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# -*- coding: utf-8 -*-
2+
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3+
#
4+
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5+
#
6+
# Unless required by applicable law or agreed to in writing,
7+
# software distributed under the License is distributed on an
8+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
from megengine import hub
10+
11+
from official.vision.detection import models
12+
13+
14+
class CustomRetinaNetConfig(models.RetinaNetConfig):
15+
def __init__(self):
16+
super().__init__()
17+
18+
# ------------------------ data cfg -------------------------- #
19+
self.train_dataset = dict(
20+
name="objects365",
21+
root="train",
22+
ann_file="annotations/objects365_train_20190423.json"
23+
)
24+
self.test_dataset = dict(
25+
name="objects365",
26+
root="val",
27+
ann_file="annotations/objects365_val_20190423.json"
28+
)
29+
30+
# ------------------------ training cfg ---------------------- #
31+
self.nr_images_epoch = 400000
32+
33+
34+
def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
35+
r"""ResNet-50 model from
36+
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
37+
"""
38+
return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
39+
40+
41+
Net = models.RetinaNet
42+
Cfg = CustomRetinaNetConfig
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# -*- coding: utf-8 -*-
2+
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3+
#
4+
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5+
#
6+
# Unless required by applicable law or agreed to in writing,
7+
# software distributed under the License is distributed on an
8+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
from megengine.data.dataset import COCO, Objects365
10+
11+
data_mapper = dict(
12+
coco=COCO,
13+
objects365=Objects365,
14+
)

0 commit comments

Comments
 (0)