
Commit e5d30db

[PaddlePaddle Paper Reproduction Challenge (Round 6)] (101) Unsupervised Learning of Visual Features by Contrasting Cluster Assignments (#120)
* add-swav
* add swav linear probe
* fix bug with axis
* update pretrain config
* add tipc
* add doc
1 parent 6e6da5a commit e5d30db

19 files changed, 831 additions and 4 deletions

README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -46,7 +46,8 @@
 | MoCo-BYOL | 300 | 71.56 | 72.10 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/mocobyol_r50_ep300_ckpt.pdparams) | [Train MoCo-BYOL](docs/Train_MoCo-BYOL_model.md) |
 | BYOL | 300 | 72.50 | 71.62 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/byol_r50_300.pdparams) | [Train BYOL](docs/Train_BYOL_model.md) |
 | PixPro | 100 | 55.1(fp16) | 57.2(fp32) | ResNet-50 | [download](https://passl.bj.bcebos.com/models/pixpro_r50_ep100_no_instance_with_linear.pdparams) | [Train PixPro](docs/Train_PixPro_model.md) |
-| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |
+| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |
+| SwAV | 100 | 72.1 | 72.4 | ResNet-50 | [download](https://drive.google.com/file/d/1budFSoQqZz1Idyej-R4E6kUnL8CGtdyu/view?usp=sharing) | [Train SwAV](docs/Train_SwAV_model.md) |
 
 > Benchmark Linear Image Classification on ImageNet-1K.
```

README_cn.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -46,7 +46,8 @@
 | MoCo-BYOL | 300 | 71.56 | 72.10 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/mocobyol_r50_ep300_ckpt.pdparams) | [Train MoCo-BYOL](docs/Train_MoCo-BYOL_model.md) |
 | BYOL | 300 | 72.50 | 71.62 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/byol_r50_300.pdparams) | [Train BYOL](docs/Train_BYOL_model.md) |
 | PixPro | 100 | 55.1(fp16) | 57.2(fp32) | ResNet-50 | [download](https://passl.bj.bcebos.com/models/pixpro_r50_ep100_no_instance_with_linear.pdparams) | [Train PixPro](docs/Train_PixPro_model.md) |
-| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |
+| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |
+| SwAV | 100 | 72.1 | 72.4 | ResNet-50 | [download](https://drive.google.com/file/d/1budFSoQqZz1Idyej-R4E6kUnL8CGtdyu/view?usp=sharing) | [Train SwAV](docs/Train_SwAV_model.md) |
 
 > Benchmark Linear Image Classification on ImageNet-1K.
```

configs/swav/swav_clas_r50.yaml

Lines changed: 81 additions & 0 deletions
```yaml
epochs: 100
output_dir: output_dir
seed: 0
device: gpu

# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference

model:
  name: Classification
  backbone:
    name: ResNetswav
    depth: 50
    frozen_stages: 4
  head:
    name: ClasHead
    with_avg_pool: true
    in_channels: 2048

dataloader:
  train:
    loader:
      num_workers: 8
      use_shared_memory: True
    sampler:
      batch_size: 32
      shuffle: true
      drop_last: true
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train
      return_label: True
      transforms:
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
  val:
    loader:
      num_workers: 8
      use_shared_memory: True
    sampler:
      batch_size: 32
      shuffle: false
      drop_last: false
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/val
      return_label: True
      transforms:
        - name: Resize
          size: 256
        - name: CenterCrop
          size: 224
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

lr_scheduler:
  name: CosineAnnealingDecay
  learning_rate: 0.3
  T_max: 100

optimizer:
  name: Momentum
  weight_decay: 1e-6

log_config:
  name: LogHook
  interval: 50

custom_config:
  - name: EvaluateHook
```

configs/swav/swav_r50_100ep.yaml

Lines changed: 65 additions & 0 deletions
```yaml
epochs: 100
output_dir: output_dir
seed: 0
device: gpu

model:
  name: SwAV
  backbone:
    name: ResNetswav
    depth: 50
  neck:
    name: SwAVNeck
    in_channels: 2048
    hid_channels: 2048
    out_channels: 128
    with_l2norm: True
    with_avg_pool: True
  head:
    name: SwAVHead
    feat_dim: 128
    sinkhorn_iterations: 3
    epsilon: 0.05
    temperature: 0.1
    crops_for_assign: [0, 1]
    num_crops: [2, 6]
    num_prototypes: 3000

dataloader:
  train:
    loader:
      num_workers: 16
      use_shared_memory: True
    sampler:
      batch_size: 128
      shuffle: true
      drop_last: true
    dataset:
      name: MultiCropDataset
      dataroot: data/ILSVRC2012/train
      size_crops: [224, 96]
      num_crops: [2, 6]
      min_scale_crops: [0.14, 0.05]
      max_scale_crops: [1., 0.14]

lr_scheduler:
  name: CosineWarmup
  learning_rate: 4.8
  T_max: 31200
  warmup_steps: 3120
  start_lr: 0.3
  end_lr: 4.8
  eta_min: 0.0048

optimizer:
  name: LarsMomentumOptimizer
  momentum: 0.9
  lars_weight_decay: 1e-6

optimizer_config:
  name: SwAVOptimizerHook
  freeze_prototypes_iters: 313

log_config:
  name: LogHook
  interval: 50
```
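The `optimizer_config` entry above wires in `SwAVOptimizerHook`, whose `freeze_prototypes_iters: 313` keeps the prototype vectors fixed during the first iterations, as in the original SwAV recipe. Here is a minimal sketch of that idea, assuming a Paddle model whose prototype parameters contain "prototypes" in their names; the function and the parameter-name check are illustrative, not the PASSL hook itself.

```python
# Illustrative sketch only, not the PASSL SwAVOptimizerHook. Intended to run
# after backward() and before the optimizer step; it assumes the prototype
# layer's parameters have "prototypes" in their names.
def maybe_freeze_prototypes(model, cur_iter, freeze_iters=313):
    if cur_iter < freeze_iters:
        for name, param in model.named_parameters():
            if "prototypes" in name:
                # Drop the loss gradient so this step does not move the prototypes.
                param.clear_gradient()
```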

docs/Train_SwAV_model.md

Lines changed: 57 additions & 0 deletions
# Train SwAV Model

## Introduction

PASSL reproduces [SwAV](https://arxiv.org/abs/2006.09882). SwAV is an online algorithm that takes advantage of contrastive methods without requiring pairwise comparisons to be computed: instead of comparing features directly, it enforces consistency between the cluster assignments ("codes") produced for different augmentations of the same image. Compared with previous contrastive methods, SwAV is more memory efficient because it needs neither a large memory bank nor a special momentum network.
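To make the head settings in the pretraining config concrete, the following is a minimal NumPy sketch of SwAV's swapped-assignment loss, for illustration only: `n_iters`, `eps`, and `temperature` correspond to `sinkhorn_iterations: 3`, `epsilon: 0.05`, and `temperature: 0.1` in `configs/swav/swav_r50_100ep.yaml`, while all function names, shapes, and toy data are assumptions rather than the PASSL implementation.

```python
import numpy as np


def sinkhorn(scores, eps=0.05, n_iters=3):
    """Sinkhorn-Knopp: turn prototype scores (B x K) into soft codes whose
    rows sum to 1 while keeping the prototypes roughly equally used."""
    Q = np.exp(scores / eps).T              # K x B
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True)   # normalize over samples, per prototype
        Q /= K
        Q /= Q.sum(axis=0, keepdims=True)   # normalize over prototypes, per sample
        Q /= B
    return (Q * B).T                        # B x K, each row sums to 1


def softmax(x, t):
    e = np.exp(x / t)
    return e / e.sum(axis=1, keepdims=True)


def swapped_loss(z1, z2, prototypes, temperature=0.1):
    """Swapped prediction: the code of one view supervises the softmax
    prediction of the other view, and vice versa."""
    s1 = z1 @ prototypes.T                  # B x K prototype scores
    s2 = z2 @ prototypes.T
    q1, q2 = sinkhorn(s1), sinkhorn(s2)     # codes, treated as targets
    p1, p2 = softmax(s1, temperature), softmax(s2, temperature)
    return -0.5 * np.mean(np.sum(q1 * np.log(p2), axis=1) +
                          np.sum(q2 * np.log(p1), axis=1))


# Toy run: 8 samples, 16-dim embeddings, 32 prototypes, all L2-normalized.
def l2n(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


rng = np.random.default_rng(0)
z1 = l2n(rng.normal(size=(8, 16)))
z2 = l2n(rng.normal(size=(8, 16)))
protos = l2n(rng.normal(size=(32, 16)))
print(swapped_loss(z1, z2, protos))
```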
## Installation
- See [INSTALL.md](INSTALL.md)

## Data Preparation
- See [GETTING_STARTED.md](GETTING_STARTED.md)

## Implemented Models
Models are all trained with the ResNet-50 backbone.

| Model | Epochs | Official results | PASSL results | Backbone | Download |
| --- | --- | --- | --- | --- | --- |
| SwAV | 100 | 72.1 | 72.4 | ResNet-50 | [download](https://drive.google.com/file/d/1budFSoQqZz1Idyej-R4E6kUnL8CGtdyu/view?usp=sharing) |
## Getting Started

### 1. Train SwAV

#### Single GPU
```
python tools/train.py -c configs/swav/swav_r50_100ep.yaml
```

#### Multiple GPUs

```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/swav/swav_r50_100ep.yaml
```

The model pretrained for 100 epochs can be found at [swav](https://drive.google.com/file/d/1budFSoQqZz1Idyej-R4E6kUnL8CGtdyu/view?usp=sharing).

Note: the default learning rate in the config files is for 8 GPUs. If you train with a different number of GPUs, the total batch size changes in proportion, so you have to scale the learning rate following `new_lr = old_lr * new_ngpus / old_ngpus`.
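For example, keeping the per-GPU batch size of 128 from `swav_r50_100ep.yaml` but training on 4 GPUs instead of 8 gives `new_lr = 4.8 * 4 / 8 = 2.4`.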
### 2. Extract backbone weights

```
python tools/extract_weight.py ${CHECKPOINT} --output ${WEIGHT_FILE} --remove_prefix
```

### 3. Evaluation on ImageNet Linear Classification

#### Train:
```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/swav/swav_clas_r50.yaml --pretrained ${WEIGHT_FILE}
```

#### Evaluate:
```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/swav/swav_clas_r50.yaml --load ${CLS_WEIGHT_FILE} --evaluate-only
```

The trained linear weights, in conjunction with the backbone weights, can be found at [swav linear](https://drive.google.com/file/d/1uduDAqJqK1uFclhQSK0d9RjzGNYR_Tj2/view?usp=sharing).

passl/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,4 +19,5 @@
 from .cifar import CIFAR10, CIFAR100
 
 from .textimagedataset import TextImageDataset
+from .multicropdataset import MultiCropDataset
 from .builder import build_dataset, build_dataloader
```

passl/datasets/multicropdataset.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from paddle.vision.transforms import (
16+
Compose,
17+
Transpose,
18+
ColorJitter,
19+
RandomResizedCrop,
20+
RandomHorizontalFlip,
21+
)
22+
from .folder import DatasetFolder
23+
from .builder import DATASETS
24+
from .preprocess.transforms import (
25+
RandomApply,
26+
GaussianBlur,
27+
NormalizeImage,
28+
RandomGrayscale,
29+
)
30+
31+
32+
@DATASETS.register()
33+
class MultiCropDataset(DatasetFolder):
34+
cls_filter = None
35+
36+
def __init__(self,
37+
dataroot,
38+
size_crops,
39+
num_crops,
40+
min_scale_crops,
41+
max_scale_crops,
42+
return_label=False):
43+
super(MultiCropDataset, self).__init__(dataroot, cls_filter=self.cls_filter)
44+
45+
assert len(size_crops) == len(num_crops)
46+
assert len(min_scale_crops) == len(num_crops)
47+
assert len(max_scale_crops) == len(num_crops)
48+
self.return_label = return_label
49+
50+
color_transform = [get_color_distortion(), get_pil_gaussian_blur()]
51+
mean = [0.485, 0.456, 0.406]
52+
std = [0.229, 0.224, 0.225]
53+
trans = []
54+
for i in range(len(size_crops)):
55+
randomresizedcrop = RandomResizedCrop(
56+
size_crops[i],
57+
scale=(min_scale_crops[i], max_scale_crops[i]),
58+
)
59+
trans.extend([Compose([
60+
randomresizedcrop,
61+
RandomHorizontalFlip(prob=0.5),
62+
Compose(color_transform),
63+
Transpose(),
64+
NormalizeImage(scale='1.0/255.0', mean=mean, std=std)])
65+
] * num_crops[i])
66+
self.trans = trans
67+
68+
def __getitem__(self, index):
69+
"""
70+
Args:
71+
index (int): Index
72+
73+
Returns:
74+
tuple: (sample, target) where target is class_index of the target class.
75+
"""
76+
path, target = self.samples[index]
77+
sample = self.loader(path)
78+
sample = list(map(lambda trans: trans(sample), self.trans))
79+
if self.return_label:
80+
return sample, target
81+
82+
return sample
83+
84+
85+
86+
def get_pil_gaussian_blur(p=0.5):
87+
gaussian_blur = GaussianBlur(sigma=[.1, 2.], _PIL=True)
88+
rnd_gaussian_blur = RandomApply([gaussian_blur], p=p)
89+
return rnd_gaussian_blur
90+
91+
92+
def get_color_distortion(s=1.0):
93+
# s is the strength of color distortion.
94+
color_jitter = ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
95+
rnd_color_jitter = RandomApply([color_jitter], p=0.8)
96+
rnd_gray = RandomGrayscale(p=0.2)
97+
color_distort = Compose([rnd_color_jitter, rnd_gray])
98+
return color_distort
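As a rough usage sketch for the dataset above, with the crop settings taken from `configs/swav/swav_r50_100ep.yaml` (the data path and ImageNet-style folder layout are assumptions), indexing the dataset returns one list of crops per image:

```python
# Hypothetical usage of MultiCropDataset with the pretraining config settings;
# requires an ImageNet-style folder at the given dataroot.
dataset = MultiCropDataset(
    dataroot="data/ILSVRC2012/train",
    size_crops=[224, 96],            # global / local crop resolutions
    num_crops=[2, 6],                # 2 global + 6 local crops per image
    min_scale_crops=[0.14, 0.05],
    max_scale_crops=[1.0, 0.14],
)
crops = dataset[0]                   # list of 8 CHW arrays: 2 at 224x224, 6 at 96x96
```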

passl/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@
 from .hook import Hook
 from .lr_scheduler_hook import LRSchedulerHook
 from .optimizer_hook import OptimizerHook
+from .optimizer_hook import SwAVOptimizerHook
 from .timer_hook import IterTimerHook
 from .log_hook import LogHook
 from .checkpoint_hook import CheckpointHook
```
