
Commit 6e6da5a

[PaddlePaddle Paper Reproduction Challenge, Round 6] (99) Exploring Simple Siamese Representation Learning (#117)

* add simsiam
* add no_bias arg in nonlinearneckv2
* add tipc
* update tipc
* update simsiam optimize builder
* add lars apx support

1 parent 0f432f1 · commit 6e6da5a

File tree

17 files changed: +543 −9 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ PASSL implements a series of self-supervised learning algorithms, See **Document
 | MoCo-BYOL | 300 | 71.56 | 72.10 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/mocobyol_r50_ep300_ckpt.pdparams) | [Train MoCo-BYOL](docs/Train_MoCo-BYOL_model.md) |
 | BYOL | 300 | 72.50 | 71.62 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/byol_r50_300.pdparams) | [Train BYOL](docs/Train_BYOL_model.md) |
 | PixPro | 100 | 55.1(fp16) | 57.2(fp32) | ResNet-50 | [download](https://passl.bj.bcebos.com/models/pixpro_r50_ep100_no_instance_with_linear.pdparams) | [Train PixPro](docs/Train_PixPro_model.md) |
+| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |

 > Benchmark Linear Image Classification on ImageNet-1K.

README_cn.md

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ PASSL implements a series of self-supervised learning algorithms; for more detailed usage documentation see
 | MoCo-BYOL | 300 | 71.56 | 72.10 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/mocobyol_r50_ep300_ckpt.pdparams) | [Train MoCo-BYOL](docs/Train_MoCo-BYOL_model.md) |
 | BYOL | 300 | 72.50 | 71.62 | ResNet-50 | [download](https://passl.bj.bcebos.com/models/byol_r50_300.pdparams) | [Train BYOL](docs/Train_BYOL_model.md) |
 | PixPro | 100 | 55.1(fp16) | 57.2(fp32) | ResNet-50 | [download](https://passl.bj.bcebos.com/models/pixpro_r50_ep100_no_instance_with_linear.pdparams) | [Train PixPro](docs/Train_PixPro_model.md) |
+| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) | [Train SimSiam](docs/Train_SimSiam_model.md) |

 > Benchmark Linear Image Classification on ImageNet-1K.

configs/simsiam/simsiam_clas_r50.yaml

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
epochs: 90
use_simclr_iters: True
global_batch_size: 4096
output_dir: output_dir
seed: 0
device: gpu

# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference

model:
  name: Classification
  backbone:
    name: ResNet
    depth: 50
    frozen_stages: 4
  head:
    name: ClasHead
    with_avg_pool: true
    in_channels: 2048

dataloader:
  train:
    loader:
      num_workers: 16
      use_shared_memory: True
    sampler:
      batch_size: 512
      shuffle: true
      drop_last: true
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train
      return_label: True
      transforms:
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
  val:
    loader:
      num_workers: 16
      use_shared_memory: True
    sampler:
      batch_size: 512
      shuffle: false
      drop_last: false
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/val
      return_label: True
      transforms:
        - name: Resize
          size: 256
        - name: CenterCrop
          size: 224
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

lr_scheduler:
  name: Cosinesimclr
  learning_rate: 1.6
  T_max: 90

optimizer:
  name: LarsMomentumOptimizer
  momentum: 0.9
  lars_weight_decay: 0.0

log_config:
  name: LogHook
  interval: 50

lr_config:
  name: LRSchedulerHook
  unit: epoch

custom_config:
  - name: EvaluateHook
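
The learning rate of 1.6 pairs with the global batch size of 4096 under the usual linear-scaling rule for large-batch LARS training, assuming a base rate of 0.1 per 256 samples (an inference from the numbers above, not something the config states; the helper below is hypothetical):

```python
def scale_lr(base_lr: float, global_batch_size: int, base_batch: int = 256) -> float:
    """Linear scaling rule: lr = base_lr * batch / base_batch (hypothetical helper)."""
    return base_lr * global_batch_size / base_batch

assert scale_lr(0.1, 4096) == 1.6  # matches learning_rate in this config
```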

configs/simsiam/simsiam_r50.yaml

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
epochs: 100
use_simclr_iters: True
global_batch_size: 512
output_dir: output_dir
seed: 0
device: gpu

model:
  name: SimSiam
  backbone:
    name: ResNet
    depth: 50
    with_pool: True
    num_classes: 2048
    zero_init_residual: True
  predictor:
    name: NonLinearNeckV2
    in_channels: 2048
    hid_channels: 512
    out_channels: 2048
    with_bias: False
    with_avg_pool: False
  head:
    name: SimSiamContrastiveHead

dataloader:
  train:
    loader:
      num_workers: 16
      use_shared_memory: True
    sampler:
      batch_size: 64
      shuffle: true
      drop_last: true
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train
      return_label: False
      return_two_sample: True
      transforms:
        - name: RandomResizedCrop
          size: 224
          scale: [0.2, 1.]
      view_trans1:
        - name: RandomApply
          transforms:
            - name: ColorJitter
              brightness: 0.4
              contrast: 0.4
              saturation: 0.4
              hue: 0.1
          p: 0.8
        - name: RandomGrayscale
          p: 0.2
        - name: RandomApply
          transforms:
            - name: GaussianBlur
              sigma: [0.1, 2.0]
          p: 0.5
        - name: RandomHorizontalFlip
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      view_trans2:
        - name: RandomApply
          transforms:
            - name: ColorJitter
              brightness: 0.4
              contrast: 0.4
              saturation: 0.4
              hue: 0.1
          p: 0.8
        - name: RandomGrayscale
          p: 0.2
        - name: RandomApply
          transforms:
            - name: GaussianBlur
              sigma: [0.1, 2.0]
          p: 0.5
        - name: RandomHorizontalFlip
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

lr_scheduler:
  name: Cosinesimclr
  learning_rate: 0.1
  T_max: 100

optimizer:
  name: Momentum
  weight_decay: 0.0001

optimizer_config:
  name: SimsiamOptimizerHook

log_config:
  name: LogHook
  interval: 50

lr_config:
  name: LRSchedulerHook
  unit: epoch
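
Per the SimSiam paper, the loss computed by the SimSiamContrastiveHead named above is the symmetrized negative cosine similarity with a stop-gradient on the target branch. A minimal paddle sketch of that formula (not PASSL's exact implementation):

```python
import paddle.nn.functional as F

def negative_cosine(p, z):
    """D(p, z) from the SimSiam paper; z is detached (stop-gradient)."""
    p = F.normalize(p, axis=1)
    z = F.normalize(z.detach(), axis=1)
    return -(p * z).sum(axis=1).mean()

def simsiam_loss(p1, p2, z1, z2):
    """Symmetrized loss: L = D(p1, z2)/2 + D(p2, z1)/2."""
    return 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)
```

Here `p1, p2` are predictor outputs and `z1, z2` encoder outputs for the two augmented views produced by `view_trans1`/`view_trans2`.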

docs/Train_SimSiam_model.md

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Train SimSiam Model

## Introduction

PASSL reproduces [SimSiam](https://arxiv.org/abs/2011.10566), a simple Siamese network for unsupervised visual representation learning.

## Installation
- See [INSTALL.md](INSTALL.md)

## Data Preparation
- See [GETTING_STARTED.md](GETTING_STARTED.md)

## Implemented Models
Models are all trained with a ResNet-50 backbone.

| | epochs | official results | passl results | Backbone | Model |
| --- | --- | ---- | ---- | ---- | ---- |
| SimSiam | 100 | 68.3 | 68.4 | ResNet-50 | [download](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing) |

## Getting Started

### 1. Train SimSiam

#### Single GPU
```
python tools/train.py -c configs/simsiam/simsiam_r50.yaml
```

#### Multiple GPUs

```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simsiam/simsiam_r50.yaml
```

The model pretrained for 100 epochs can be found at [simsiam](https://drive.google.com/file/d/1kaAm8-tlvB570kzI4fo9h4dwGQFf_4FE/view?usp=sharing).

Note: The default learning rate in the config files is for 8 GPUs. If you use a different number of GPUs, the total batch size changes in proportion, so you must scale the learning rate as `new_lr = old_lr * new_ngpus / old_ngpus`.
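
As an example of the rule (the helper below is hypothetical, not a PASSL utility): dropping from the default 8 GPUs to 4 halves the total batch size, so the 0.1 base learning rate becomes 0.05.

```python
def rescale_lr(old_lr: float, old_ngpus: int, new_ngpus: int) -> float:
    """new_lr = old_lr * new_ngpus / old_ngpus (per-GPU batch size unchanged)."""
    return old_lr * new_ngpus / old_ngpus

print(rescale_lr(0.1, old_ngpus=8, new_ngpus=4))  # 0.05
```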
### 2. Extract backbone weights

```
python tools/extract_weight.py ${CHECKPOINT} --output ${WEIGHT_FILE} --prefix encoder --remove_prefix
```
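
Conceptually, this keeps only the state-dict entries under the given prefix and strips it, so the backbone can be loaded standalone. A minimal sketch of that idea, assuming paddle checkpoint files and the `encoder.` key layout (not the actual `tools/extract_weight.py`):

```python
import paddle

def extract_prefixed_weights(ckpt_path, out_path, prefix="encoder.", remove_prefix=True):
    """Keep only keys under `prefix`, optionally stripping it (hypothetical sketch)."""
    state_dict = paddle.load(ckpt_path)
    extracted = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):] if remove_prefix else key
            extracted[new_key] = value
    paddle.save(extracted, out_path)
```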
### 3. Evaluation on ImageNet Linear Classification

#### Train:
```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simsiam/simsiam_clas_r50.yaml --pretrained ${WEIGHT_FILE}
```

#### Evaluate:
```
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simsiam/simsiam_clas_r50.yaml --load ${CLS_WEIGHT_FILE} --evaluate-only
```

The trained linear weights, in conjunction with the backbone weights, can be found at [simsiam linear](https://drive.google.com/file/d/19smHZGhBEPWeyLjKIGhM7KPngr-8BOUl/view?usp=sharing).

passl/engine/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -135,6 +135,7 @@ def __init__(self, cfg):
         n_parameters = sum(p.numel() for p in self.model.parameters()
                            if not p.stop_gradient).item()
+
         i = int(math.log(n_parameters, 10) // 3)
         size_unit = ['', 'K', 'M', 'B', 'T', 'Q']
         self.logger.info("Number of Parameters is {:.2f}{}.".format(
@@ -163,6 +164,7 @@ def __init__(self, cfg):
         else:
             self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler,
                                                    self.iters_per_epoch)
+
         self.optimizer = build_optimizer(cfg.optimizer, self.lr_scheduler,
                                          [self.model])

@@ -439,7 +441,6 @@ def load(self, weight_path, export=False):
         for k, v in state_dict.items():
             state_dict_['model.' + k] = v
         state_dict = state_dict_
-
         self.model.set_state_dict(state_dict)

     def export(self, ckpt):

passl/hooks/optimizer_hook.py

Lines changed: 56 additions & 2 deletions
@@ -14,6 +14,7 @@

 from .hook import Hook
 from .builder import HOOKS
+from ..solver.builder import build_optimizer


 @HOOKS.register()
@@ -33,15 +34,68 @@ def train_iter_end(self, trainer):
         if trainer.use_amp:
             scaled_loss = trainer.scaler.scale(loss)
             scaled_loss.backward()
-            trainer.scaler.step(trainer.optimizer)
-            trainer.scaler.update()
+            if 'lars' in trainer.optimizer.type:
+                trainer.scaler.minimize(trainer.optimizer, scaled_loss)
+            else:
+                trainer.scaler.step(trainer.optimizer)
+                trainer.scaler.update()
+        else:
+            loss.backward()
+            if 'lars' in trainer.optimizer.type:
+                trainer.optimizer.minimize(loss)
+            else:
+                trainer.optimizer.step()

+        if 'loss' not in trainer.outputs:
+            trainer.outputs['loss'] = loss
+
+
+@HOOKS.register()
+class SimsiamOptimizerHook(Hook):
+    def __init__(self, priority=1):
+        self.priority = priority
+
+    def run_begin(self, trainer):
+        if hasattr(trainer.model, '_layers'):
+            model = trainer.model._layers
+        else:
+            model = trainer.model
+
+        # build separate optimizers for the encoder and the predictor
+        trainer.optimizer = build_optimizer(
+            trainer.cfg.optimizer, trainer.lr_scheduler, [model.encoder])
+        trainer.predictor_optimizer = build_optimizer(
+            trainer.cfg.optimizer, trainer.lr_scheduler.get_lr(), [model.predictor])
+
+    def train_iter_end(self, trainer):
+        if 'Lars' in trainer.cfg['optimizer']['name']:
+            trainer.optimizer.clear_gradients()
+            trainer.predictor_optimizer.clear_gradients()
+        else:
+            trainer.optimizer.clear_grad()
+            trainer.predictor_optimizer.clear_grad()
+
+        loss = trainer.outputs['loss']
+
+        if trainer.use_amp:
+            scaled_loss = trainer.scaler.scale(loss)
+            scaled_loss.backward()
+            if 'lars' in trainer.optimizer.type:
+                trainer.scaler.minimize(trainer.optimizer, scaled_loss)
+                trainer.scaler.minimize(trainer.predictor_optimizer, scaled_loss)
+            else:
+                trainer.scaler.step(trainer.optimizer)
+                trainer.scaler.step(trainer.predictor_optimizer)
+                trainer.scaler.update()
         else:
             loss.backward()
             if 'lars' in trainer.optimizer.type:
                 trainer.optimizer.minimize(loss)
+                trainer.predictor_optimizer.minimize(loss)
             else:
                 trainer.optimizer.step()
+                trainer.predictor_optimizer.step()

         if 'loss' not in trainer.outputs:
             trainer.outputs['loss'] = loss
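
The point of SimsiamOptimizerHook is the SimSiam paper's recipe of a cosine-decayed learning rate for the encoder but a fixed learning rate for the predictor: the encoder optimizer receives the scheduler object, while the predictor optimizer receives only the scheduler's current value. A minimal paddle sketch of that split, with stand-in encoder/predictor modules (not PASSL's builder code):

```python
import paddle
import paddle.nn as nn

# stand-in modules; in PASSL these are model.encoder and model.predictor
encoder = nn.Linear(2048, 2048)
predictor = nn.Linear(2048, 2048)

# cosine-decayed LR drives the encoder
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.1, T_max=100)
encoder_opt = paddle.optimizer.Momentum(
    learning_rate=scheduler, momentum=0.9,
    parameters=encoder.parameters(), weight_decay=1e-4)

# the predictor gets the scheduler's *current* value, i.e. a constant float LR
predictor_opt = paddle.optimizer.Momentum(
    learning_rate=scheduler.get_lr(), momentum=0.9,
    parameters=predictor.parameters(), weight_decay=1e-4)
```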

passl/modeling/architectures/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .clas import Classification
 from .moco import MoCo
 from .simclr import SimCLR
+from .simsiam import SimSiam
 from .pixpro import PixPro

 from .BEiTWrapper import BEiTWrapper, BEiTPTWrapper, BEiTFTWrapper
