Skip to content

Commit 4aaaf4d

Browse files
[Feature] Support Class Aware Sampler (#7436)
* [Feature] Support Class Aware Sampler * minor fix * minor fix * rename get_label_dict to get_index_dict * fix cas logic * minor fix * minor fix * minor fix * minor fix * minor fix
1 parent 24f2fdb commit 4aaaf4d

File tree

11 files changed

+320
-31
lines changed

11 files changed

+320
-31
lines changed

configs/openimages/README.md

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Open Images Dataset
2-
<!-- [DATASET] -->
32

3+
> [Open Images Dataset](https://arxiv.org/abs/1811.00982)
4+
5+
<!-- [DATASET] -->
46
## Abstract
57

68
<!-- [ABSTRACT] -->
@@ -90,14 +92,14 @@ training/testing by using `tools/misc/get_image_metas.py`.
9092
│ │ │ ├── class-descriptions-boxable.csv
9193
│ │ │ ├── oidv6-train-annotations-bbox.scv
9294
│ │ │ ├── validation-annotations-bbox.csv
93-
│ │ │ ├── validation-annotations-human-imagelabels-boxable.csv # is not necessary
95+
│ │ │ ├── validation-annotations-human-imagelabels-boxable.csv
9496
│ │ │ ├── validation-image-metas.pkl # get from script
9597
│ │ ├── challenge2019
9698
│ │ │ ├── challenge-2019-train-detection-bbox.txt
9799
│ │ │ ├── challenge-2019-validation-detection-bbox.txt
98100
│ │ │ ├── class_label_tree.np
99101
│ │ │ ├── class_sample_train.pkl
100-
│ │ │ ├── challenge-2019-validation-detection-human-imagelabels.csv # download from official website, not necessary
102+
│ │ │ ├── challenge-2019-validation-detection-human-imagelabels.csv # download from official website
101103
│ │ │ ├── challenge-2019-validation-metas.pkl # get from script
102104
│ │ ├── OpenImages
103105
│ │ │ ├── train # training images
@@ -112,14 +114,30 @@ Open Images v6, but the test images are different.
112114
You can also download the annotations from [official website](https://storage.googleapis.com/openimages/web/challenge2019_downloads.html),
113115
and set data.train.type=OpenImagesDataset, data.val.type=OpenImagesDataset, and data.test.type=OpenImagesDataset in the config
114116
3. If users do not want to use `validation-annotations-human-imagelabels-boxable.csv` and `challenge-2019-validation-detection-human-imagelabels.csv`
115-
users can should set `data.val.load_image_level_labels=False` and `data.test.load_image_level_labels=False` in the config .
116-
117+
users can set `data.val.load_image_level_labels=False` and `data.test.load_image_level_labels=False` in the config.
118+
Please note that loading image-levels label is the default of Open Images evaluation metric.
119+
More details please refer to the [official website](https://storage.googleapis.com/openimages/web/evaluation.html)
117120
118121
## Results and Models
119122
120123
| Architecture | Backbone | Style | Lr schd | Sampler | Mem (GB) | Inf time (fps) | box AP | Config | Download |
121124
|:------------:|:---------:|:-------:|:-------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:|
122125
| Faster R-CNN | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 51.6 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159-e87ab7ce.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159.log.json) |
123-
| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 54.5 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252-46380cde.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252.log.json) |
126+
| Faster R-CNN | R-50 | pytorch | 1x | Class Aware Sampler | 7.7 | - | 60.0 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424-98c630e5.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424.log.json) |
127+
| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 54.9 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100-0e79e5df.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100.log.json) |
128+
| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Class Aware Sampler | 7.1 | - | 65.0 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021-34c402d9.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021.log.json) |
124129
| Retinanet | R-50 | pytorch | 1x | Group Sampler | 6.6 | - | 61.5 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/retinanet_r50_fpn_32x2_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954-d2ae5462.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954.log.json) |
125-
| SSD | VGG16 | pytorch | 36e | Group Sampler | 10.8 | - | 35.4 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/ssd300_32x8_36e_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232-dce93846.pth) &#124; [log](ttps://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232.log.json) |
130+
| SSD | VGG16 | pytorch | 36e | Group Sampler | 10.8 | - | 35.4 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/ssd300_32x8_36e_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232-dce93846.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232.log.json) |
131+
132+
**Notes:**
133+
134+
- 'cas' is short for 'Class Aware Sampler'
135+
136+
### Results of consider image level labels
137+
138+
| Architecture | Sampler | Consider Image Level Labels | box AP|
139+
|:------------:|:-------:|:---------------------------:|:-----:|
140+
|Faster R-CNN r50 (Challenge 2019)| Group Sampler| w/o | 62.19 |
141+
|Faster R-CNN r50 (Challenge 2019)| Group Sampler| w/ | 54.87 |
142+
|Faster R-CNN r50 (Challenge 2019)| Class Aware Sampler| w/o | 71.77 |
143+
|Faster R-CNN r50 (Challenge 2019)| Class Aware Sampler| w/ | 64.98 |
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = ['faster_rcnn_r50_fpn_32x2_1x_openimages.py']
2+
3+
# Use ClassAwareSampler
4+
data = dict(
5+
train_dataloader=dict(class_aware_sampler=dict(num_sample_class=1)))
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = ['faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py']
2+
3+
# Use ClassAwareSampler
4+
data = dict(
5+
train_dataloader=dict(class_aware_sampler=dict(num_sample_class=1)))

configs/openimages/metafile.yml

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
1-
Collections:
2-
- Name: Open Images Dataset
3-
Paper:
4-
URL: https://arxiv.org/abs/1811.00982
5-
Title: 'The Open Images Dataset V4: Unified image classification, object detection, and visual relationship detection at scale'
6-
README: configs/openimages/README.md
7-
Code:
8-
URL: https://github.com/open-mmlab/mmdetection/blob/v2.20.0/mmdet/datasets/openimages.py#L21
9-
Version: v2.20.0
10-
111
Models:
122
- Name: faster_rcnn_r50_fpn_32x2_1x_openimages
13-
In Collection: Open Images Dataset
3+
In Collection: Faster R-CNN
144
Config: configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages.py
155
Metadata:
166
Training Memory (GB): 7.7
177
Epochs: 12
8+
Training Data: Open Images v6
9+
Training Techniques:
10+
- SGD with Momentum
11+
- Weight Decay
1812
Results:
1913
- Task: Object Detection
2014
Dataset: Open Images v6
@@ -23,11 +17,15 @@ Models:
2317
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159-e87ab7ce.pth
2418

2519
- Name: retinanet_r50_fpn_32x2_1x_openimages
26-
In Collection: Open Images Dataset
20+
In Collection: RetinaNet
2721
Config: configs/openimages/retinanet_r50_fpn_32x2_1x_openimages.py
2822
Metadata:
2923
Training Memory (GB): 6.6
3024
Epochs: 12
25+
Training Data: Open Images v6
26+
Training Techniques:
27+
- SGD with Momentum
28+
- Weight Decay
3129
Results:
3230
- Task: Object Detection
3331
Dataset: Open Images v6
@@ -36,11 +34,15 @@ Models:
3634
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954-d2ae5462.pth
3735

3836
- Name: ssd300_32x8_36e_openimages
39-
In Collection: Open Images Dataset
37+
In Collection: SSD
4038
Config: configs/openimages/ssd300_32x8_36e_openimages
4139
Metadata:
4240
Training Memory (GB): 10.8
43-
Epochs: 12
41+
Epochs: 36
42+
Training Data: Open Images v6
43+
Training Techniques:
44+
- SGD with Momentum
45+
- Weight Decay
4446
Results:
4547
- Task: Object Detection
4648
Dataset: Open Images v6
@@ -49,14 +51,52 @@ Models:
4951
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232-dce93846.pth
5052

5153
- Name: faster_rcnn_r50_fpn_32x2_1x_openimages_challenge
52-
In Collection: Open Images Dataset
54+
In Collection: Faster R-CNN
5355
Config: configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py
5456
Metadata:
5557
Training Memory (GB): 7.7
5658
Epochs: 12
59+
Training Data: Open Images Challenge 2019
60+
Training Techniques:
61+
- SGD with Momentum
62+
- Weight Decay
63+
Results:
64+
- Task: Object Detection
65+
Dataset: Open Images Challenge 2019
66+
Metrics:
67+
box AP: 54.9
68+
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100-0e79e5df.pth
69+
70+
- Name: faster_rcnn_r50_fpn_32x2_cas_1x_openimages
71+
In Collection: Faster R-CNN
72+
Config: configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py
73+
Metadata:
74+
Training Memory (GB): 7.7
75+
Epochs: 12
76+
Training Data: Open Images Challenge 2019
77+
Training Techniques:
78+
- SGD with Momentum
79+
- Weight Decay
80+
Results:
81+
- Task: Object Detection
82+
Dataset: Open Images Challenge 2019
83+
Metrics:
84+
box AP: 60.0
85+
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424-98c630e5.pth
86+
87+
- Name: faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge
88+
In Collection: Faster R-CNN
89+
Config: configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py
90+
Metadata:
91+
Training Memory (GB): 7.1
92+
Epochs: 12
93+
Training Data: Open Images Challenge 2019
94+
Training Techniques:
95+
- SGD with Momentum
96+
- Weight Decay
5797
Results:
5898
- Task: Object Detection
59-
Dataset: Open Images Challenge 2019W
99+
Dataset: Open Images Challenge 2019
60100
Metrics:
61-
box AP: 54.5
62-
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252-46380cde.pth
101+
box AP: 65.0
102+
Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021-34c402d9.pth

mmdet/datasets/builder.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
1313
from torch.utils.data import DataLoader
1414

15-
from .samplers import (DistributedGroupSampler, DistributedSampler,
16-
GroupSampler, InfiniteBatchSampler,
15+
from .samplers import (ClassAwareSampler, DistributedGroupSampler,
16+
DistributedSampler, GroupSampler, InfiniteBatchSampler,
1717
InfiniteGroupBatchSampler)
1818

1919
if platform.system() != 'Windows':
@@ -93,6 +93,7 @@ def build_dataloader(dataset,
9393
seed=None,
9494
runner_type='EpochBasedRunner',
9595
persistent_workers=False,
96+
class_aware_sampler=None,
9697
**kwargs):
9798
"""Build PyTorch DataLoader.
9899
@@ -115,6 +116,8 @@ def build_dataloader(dataset,
115116
the worker processes after a dataset has been consumed once.
116117
This allows to maintain the workers `Dataset` instances alive.
117118
This argument is only valid when PyTorch>=1.7.0. Default: False.
119+
class_aware_sampler (dict): Whether to use `ClassAwareSampler`
120+
during training. Default: None.
118121
kwargs: any keyword argument to be used to initialize DataLoader
119122
120123
Returns:
@@ -153,7 +156,18 @@ def build_dataloader(dataset,
153156
batch_size = 1
154157
sampler = None
155158
else:
156-
if dist:
159+
if class_aware_sampler is not None:
160+
# ClassAwareSampler can be used in both distributed and
161+
# non-distributed training.
162+
num_sample_class = class_aware_sampler.get('num_sample_class', 1)
163+
sampler = ClassAwareSampler(
164+
dataset,
165+
samples_per_gpu,
166+
world_size,
167+
rank,
168+
seed=seed,
169+
num_sample_class=num_sample_class)
170+
elif dist:
157171
# DistributedGroupSampler will definitely shuffle the data to
158172
# satisfy that images on each GPU are in the same group
159173
if shuffle:

mmdet/datasets/custom.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,25 @@ def get_classes(cls, classes=None):
285285

286286
return class_names
287287

288+
def get_cat2imgs(self):
289+
"""Get a dict with class as key and img_ids as values, which will be
290+
used in :class:`ClassAwareSampler`.
291+
292+
Returns:
293+
dict[list]: A dict of per-label image list,
294+
the item of the dict indicates a label index,
295+
corresponds to the image index that contains the label.
296+
"""
297+
if self.CLASSES is None:
298+
raise ValueError('self.CLASSES can not be None')
299+
# sort the label index
300+
cat2imgs = {i: [] for i in range(len(self.CLASSES))}
301+
for i in range(len(self)):
302+
cat_ids = set(self.get_cat_ids(i))
303+
for cat in cat_ids:
304+
cat2imgs[cat].append(i)
305+
return cat2imgs
306+
288307
def format_results(self, results, **kwargs):
289308
"""Place holder to format result to dataset specific output."""
290309

mmdet/datasets/openimages.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,17 @@ def denormalize_gt_bboxes(self, annotations):
601601
annotations[i]['bboxes'][:, 1::2] *= h
602602
return annotations
603603

604+
def get_cat_ids(self, idx):
605+
"""Get category ids by index.
606+
607+
Args:
608+
idx (int): Index of data.
609+
610+
Returns:
611+
list[int]: All categories in the image of specified index.
612+
"""
613+
return self.get_ann_info(idx)['labels'].astype(np.int).tolist()
614+
604615
def evaluate(self,
605616
results,
606617
metric='mAP',

mmdet/datasets/pipelines/loading.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,11 @@ def _load_bboxes(self, results):
256256
results['gt_bboxes'] = ann_info['bboxes'].copy()
257257

258258
if self.denorm_bbox:
259-
h, w = results['img_shape'][:2]
260259
bbox_num = results['gt_bboxes'].shape[0]
261260
if bbox_num != 0:
261+
h, w = results['img_shape'][:2]
262262
results['gt_bboxes'][:, 0::2] *= w
263263
results['gt_bboxes'][:, 1::2] *= h
264-
results['gt_bboxes'] = results['gt_bboxes'].astype(np.float32)
265264

266265
gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
267266
if gt_bboxes_ignore is not None:
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .class_aware_sampler import ClassAwareSampler
23
from .distributed_sampler import DistributedSampler
34
from .group_sampler import DistributedGroupSampler, GroupSampler
45
from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler
56

67
__all__ = [
78
'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler',
8-
'InfiniteGroupBatchSampler', 'InfiniteBatchSampler'
9+
'InfiniteGroupBatchSampler', 'InfiniteBatchSampler', 'ClassAwareSampler'
910
]

0 commit comments

Comments
 (0)