Commit a3a660f — init commit (0 parents)

File tree: 1,628 files changed, +126,741 −0 lines changed


LICENSE

Lines changed: 21 additions & 0 deletions
MIT License

Copyright (c) 2021 Qi Han

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 77 additions & 0 deletions
# Demystifying Local Vision Transformer, [arXiv](https://arxiv.org/pdf/2106.04263.pdf)

This is the official PyTorch implementation of our paper. We simply replace local self-attention with (dynamic) depth-wise convolution at lower computational cost. The performance is on par with the Swin Transformer.

Beyond that, the main contribution of our paper is a theoretical and detailed comparison between depth-wise convolution and local self-attention from three aspects: sparse connectivity, weight sharing, and dynamic weight. With this paper, we want the community to rethink local self-attention and depth-wise convolution, as well as the basic rules of model architecture design.

<p align="center">
<img width="600" height="300" src="figures/relation.png">
</p>

Codes and models for object detection and semantic segmentation are available in `downstreams`.

We also provide MLP-based Swin Transformer models and inhomogeneous dynamic convolution in the ablation studies. These codes and models will come soon.
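The replacement described above is easy to picture: a depth-wise convolution aggregates each channel over a local window with one kernel shared across all positions, whereas local self-attention computes position-specific (dynamic) weights. A minimal NumPy sketch of the static depth-wise case, purely illustrative and not the repository's implementation:

```python
import numpy as np

def depthwise_conv2d(x, w):
    """Depth-wise 2D convolution ('same' zero padding, stride 1).

    x: (C, H, W) feature map; w: (C, k, k), one kernel per channel.
    Each channel is filtered independently -- the sparse-connectivity
    pattern the paper compares against local self-attention.
    """
    C, H, W = x.shape
    k = w.shape[1]
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    out = np.empty_like(x, dtype=float)
    for c in range(C):                      # no cross-channel mixing
        for i in range(H):
            for j in range(W):
                out[c, i, j] = np.sum(xp[c, i:i + k, j:j + k] * w[c])
    return out

x = np.ones((3, 7, 7))
w = np.full((3, 3, 3), 1 / 9)              # 3x3 averaging kernel per channel
y = depthwise_conv2d(x, w)
```

The dynamic variant in the paper additionally predicts the per-channel kernels from the input instead of learning them as fixed parameters.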
## Reference
```
@article{han2021demystifying,
  title={Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight},
  author={Han, Qi and Fan, Zejia and Dai, Qi and Sun, Lei and Cheng, Ming-Ming and Liu, Jiaying and Wang, Jingdong},
  journal={arXiv preprint arXiv:2106.04263},
  year={2021}
}
```
## 1. Requirements
torch>=1.5.0, torchvision, [timm](https://github.com/rwightman/pytorch-image-models), pyyaml; apex-amp

Data preparation: ImageNet dataset with the following structure:
```
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```
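This is the standard one-subdirectory-per-class layout. Before launching a long training run, it can be worth sanity-checking the directory tree; a small stdlib sketch (the helper name and the check itself are ours, not part of the repo):

```python
from pathlib import Path

def check_imagefolder_layout(root):
    """Verify the train/val class-folder layout described above.

    Returns {split: number_of_class_directories}; raises if a split
    is missing or contains no class directories at all.
    """
    root = Path(root)
    counts = {}
    for split in ("train", "val"):
        d = root / split
        if not d.is_dir():
            raise FileNotFoundError(f"missing split directory: {d}")
        classes = [p for p in d.iterdir() if p.is_dir()]
        if not classes:
            raise ValueError(f"no class directories under {d}")
        counts[split] = len(classes)
    return counts
```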
## 2. Training
For the tiny model, we train with batch size 128 on 8 GPUs. When training the base model, we use batch size 64 on 16 GPUs with OpenMPI to keep the total batch size unchanged. (With the same training setting, the base model could not be trained with AMP due to anomalous gradient values.)

Please change the data path in the sh scripts first.

For the tiny model:
```bash
bash scripts/run_dwnet_tiny_patch4_window7_224.sh
bash scripts/run_dynamic_dwnet_tiny_patch4_window7_224.sh
```
For the base model, use multiple nodes with OpenMPI:
```bash
bash scripts/run_dwnet_base_patch4_window7_224.sh
bash scripts/run_dynamic_dwnet_base_patch4_window7_224.sh
```
## 3. Evaluation
```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/change_to_config_file --resume /path/to/model --data-path /path/to/imagenet --eval
```
## 4. Models
Models are trained on ImageNet at resolution 224.

| Model | #params | FLOPs | Top-1 Acc. | Download |
| :--- | :---: | :---: | :---: | :---: |
| dwnet_tiny | 24M | 3.8G | 81.2 | [github](https://github.com/Atten4Vis/DemystifyLocalViT/releases/download/prerelease/dwnet_tiny_224.pth) |
| dynamic_dwnet_tiny | 51M | 3.8G | 81.8 | [github](https://github.com/Atten4Vis/DemystifyLocalViT/releases/download/prerelease/dynamic_dwnet_tiny_224.pth) |
| dwnet_base | 74M | 12.9G | 83.2 | [github](https://github.com/Atten4Vis/DemystifyLocalViT/releases/download/prerelease/dwnet_base_224.pth) |
| dynamic_dwnet_base | 162M | 13.0G | 83.2 | [github](https://github.com/Atten4Vis/DemystifyLocalViT/releases/download/prerelease/dynamic_dwnet_base_224.pth) |
75+
76+
## LICENSE
77+
This repo is under the MIT license. Some codes are borrow from [Swin Transformer](https://github.com/microsoft/Swin-Transformer).

config.py

Lines changed: 252 additions & 0 deletions
```python
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache data in memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# DWNet parameters
_C.MODEL.DWNET = CN()
_C.MODEL.DWNET.PATCH_SIZE = 4
_C.MODEL.DWNET.IN_CHANS = 3
_C.MODEL.DWNET.EMBED_DIM = 96
_C.MODEL.DWNET.DEPTHS = [2, 2, 6, 2]
_C.MODEL.DWNET.WINDOW_SIZE = 7
_C.MODEL.DWNET.MLP_RATIO = 4.
_C.MODEL.DWNET.APE = False
_C.MODEL.DWNET.PATCH_NORM = True
_C.MODEL.DWNET.CONV_TYPE = "v1"
_C.MODEL.DWNET.DYNAMIC = False

# Halo Transformer parameters
_C.MODEL.HALO = CN()
_C.MODEL.HALO.PATCH_SIZE = 4
_C.MODEL.HALO.IN_CHANS = 3
_C.MODEL.HALO.EMBED_DIM = 96
_C.MODEL.HALO.DEPTHS = [2, 2, 6, 2]
_C.MODEL.HALO.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.HALO.WINDOW_SIZE = [7, 7, 7, 7]
_C.MODEL.HALO.HALO_SIZE = [3, 3, 3, 3]
_C.MODEL.HALO.MLP_RATIO = 4.
_C.MODEL.HALO.QKV_BIAS = True
_C.MODEL.HALO.QK_SCALE = None
_C.MODEL.HALO.APE = False
_C.MODEL.HALO.PATCH_NORM = True

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = False
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# AutoAugment policy, e.g. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level; if 'O0', no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to log info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# Local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Merge any base config files listed under BASE first,
    # so that values in cfg_file override them.
    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True
    if args.data_set == 'CIFAR':
        config.DATA.DATASET = 'cifar'
    elif args.data_set == 'IMNET':
        config.DATA.DATASET = 'imagenet'
    if args.epoch != 300:
        config.TRAIN.EPOCHS = args.epoch

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)
    return config
```
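The `BASE` mechanism in `_update_config_from_file` is a recursive merge: any base files are applied first, then the current file overrides them. A dependency-free sketch of the same pattern over plain dicts (the helper names and the in-memory `files` map are hypothetical, for illustration only):

```python
def deep_merge(dst, src):
    """Recursively merge src into dst; src values win on conflicts."""
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            deep_merge(dst[key], value)
        else:
            dst[key] = value
    return dst

def load_config(name, files):
    """Mimic the BASE-file recursion over an in-memory {name: cfg} map."""
    cfg = files[name]
    merged = {}
    for base in cfg.get('BASE', []):
        if base:
            deep_merge(merged, load_config(base, files))
    # the current file's own keys override everything inherited
    deep_merge(merged, {k: v for k, v in cfg.items() if k != 'BASE'})
    return merged

files = {
    'base.yaml': {'MODEL': {'TYPE': 'swin', 'DROP_PATH_RATE': 0.1}},
    'tiny.yaml': {'BASE': ['base.yaml'],
                  'MODEL': {'TYPE': 'dwnet', 'NAME': 'dwnet_tiny'}},
}
cfg = load_config('tiny.yaml', files)
# The derived file overrides TYPE but inherits DROP_PATH_RATE.
```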
Lines changed: 8 additions & 0 deletions
```yaml
MODEL:
  TYPE: dwnet
  NAME: dwnet_base_patch4_window7_224
  DROP_PATH_RATE: 0.5
  DWNET:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    WINDOW_SIZE: 7
```
Lines changed: 8 additions & 0 deletions
```yaml
MODEL:
  TYPE: dwnet
  NAME: dwnet_tiny_patch4_window7_224
  DROP_PATH_RATE: 0.2
  DWNET:
    EMBED_DIM: 96
    DEPTHS: [ 2, 2, 6, 2 ]
    WINDOW_SIZE: 7
```
Lines changed: 9 additions & 0 deletions
```yaml
MODEL:
  TYPE: ddwnet
  NAME: ddwnet_base_patch4_window7_224
  DROP_PATH_RATE: 0.5
  DWNET:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    WINDOW_SIZE: 7
    DYNAMIC: True
```
Lines changed: 9 additions & 0 deletions
```yaml
MODEL:
  TYPE: ddwnet
  NAME: ddwnet_tiny_patch4_window7_224
  DROP_PATH_RATE: 0.2
  DWNET:
    EMBED_DIM: 96
    DEPTHS: [ 2, 2, 6, 2 ]
    WINDOW_SIZE: 7
    DYNAMIC: True
```

data/__init__.py

Lines changed: 1 addition & 0 deletions
```python
from .build import build_loader
```
