
Commit 8df4270: init (0 parents)


117 files changed: 16,299 additions, 0 deletions


.gitignore

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
wandb
coreml
pretrain
**/__pycache__
pretrain
ignore
*.zip
checkpoints
trt

README.md

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# [RepViT: Revisiting Mobile CNN From ViT Perspective](https://arxiv.org/abs/2307.09283)

Official PyTorch implementation of **RepViT**, from the following paper:

[RepViT: Revisiting Mobile CNN From ViT Perspective](https://arxiv.org/abs/2307.09283).\
Ao Wang, Hui Chen, Zijia Lin, Hengjun Pu, and Guiguang Ding\
[[`arXiv`](https://arxiv.org/abs/2307.09283)]

<p align="center">
  <img src="figures/latency.png" width=70%> <br>
  Models are trained on ImageNet-1K and deployed on an iPhone 12 with Core ML Tools to measure latency.
</p>

<details>
  <summary>
  <font size="+1">Abstract</font>
  </summary>
Recently, lightweight Vision Transformers (ViTs) have demonstrated superior performance and lower latency than lightweight Convolutional Neural Networks (CNNs) on resource-constrained mobile devices. This improvement is usually attributed to the multi-head self-attention module, which enables the model to learn global representations. However, the architectural disparities between lightweight ViTs and lightweight CNNs have not been adequately examined. In this study, we revisit the efficient design of lightweight CNNs and emphasize their potential for mobile devices. We incrementally enhance the mobile-friendliness of a standard lightweight CNN, specifically MobileNetV3, by integrating the efficient architectural choices of lightweight ViTs. This yields a new family of pure lightweight CNNs, namely RepViT. Extensive experiments show that RepViT outperforms existing state-of-the-art lightweight ViTs and exhibits favorable latency across various vision tasks. On ImageNet, RepViT achieves over 80% top-1 accuracy with nearly 1ms latency on an iPhone 12, which, to the best of our knowledge, is a first for a lightweight model. Our largest model, RepViT-M3, obtains 81.4% accuracy with only 1.3ms latency.
</details>

<br>

## Classification on ImageNet-1K

### Models

| Model | Top-1 (300 epochs) | #Params | MACs | Latency | Ckpt | Core ML | Log |
|:---------------|:----:|:---:|:--:|:--:|:--:|:--:|:--:|
| RepViT-M1 | 78.5 | 5.1M | 0.8G | 0.9ms | [M1](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m1_distill_300.pth) | [M1](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m1_224.mlmodel) | [M1](./logs/repvit_m1_train.log) |
| RepViT-M2 | 80.6 | 8.8M | 1.4G | 1.1ms | [M2](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m2_distill_300.pth) | [M2](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m2_224.mlmodel) | [M2](./logs/repvit_m2_train.log) |
| RepViT-M3 | 81.4 | 10.1M | 1.9G | 1.3ms | [M3](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m3_distill_300.pth) | [M3](https://github.com/jameslahm/RepViT/releases/download/untagged-75eb9e1fea235b938f50/repvit_m3_224.mlmodel) | [M3](./logs/repvit_m3_train.log) |

Tip: convert a training-time RepViT model into the inference-time structure:
```python
from timm.models import create_model
import utils

model = create_model('repvit_m1')
utils.replace_batchnorm(model)
```
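
A minimal sanity check after the snippet above, as a sketch: it assumes the fused model still accepts a standard 224×224 ImageNet input and produces 1000 class logits, and that the repository's model definitions have been imported so that `repvit_m1` is registered with timm.
```python
import torch
from timm.models import create_model
import utils

model = create_model('repvit_m1')
utils.replace_batchnorm(model)  # fuse BatchNorm layers for inference, as in the snippet above
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # dummy ImageNet-sized input (assumed 224x224)
print(out.shape)  # expected: torch.Size([1, 1000]) for ImageNet-1K classification
```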

## Latency Measurement

The latency reported for RepViT on an iPhone 12 (iOS 16) is measured with the benchmark tool from [Xcode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
For example, here is a latency measurement of RepViT-M1:

![](./figures/repvit_m1_latency.png)

Tip: export the model to a Core ML model:
```
python export_coreml.py --model repvit_m1 --ckpt pretrain/repvit_m1_distill_300.pth
```
Tip: measure the throughput on a GPU:
```
python speed_gpu.py --model repvit_m1
```


## ImageNet

### Prerequisites
A `conda` virtual environment is recommended.
```
conda create -n repvit python=3.8
pip install -r requirements.txt
```

### Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The training and validation data are expected to be in the `train` and `val` folders, respectively:
```
|-- /path/to/imagenet/
    |-- train
    |-- val
```

### Training
To train RepViT-M1 on an 8-GPU machine:

```
python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repvit_m1 --data-path ~/imagenet --dist-eval
```
Tip: remember to specify your own data path and model name.

### Testing
For example, to test RepViT-M3:
```
python main.py --eval --model repvit_m3 --resume pretrain/repvit_m3_distill_300.pth --data-path ~/imagenet
```

## Downstream Tasks
[Object Detection and Instance Segmentation](detection/README.md)<br>
[Semantic Segmentation](segmentation/README.md)

## Acknowledgement

The classification (ImageNet) code base is partly built on [LeViT](https://github.com/facebookresearch/LeViT), [PoolFormer](https://github.com/sail-sg/poolformer), and [EfficientFormer](https://github.com/snap-research/EfficientFormer).

The detection and segmentation pipelines are from [MMCV](https://github.com/open-mmlab/mmcv) ([MMDetection](https://github.com/open-mmlab/mmdetection) and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation)).

Thanks for the great implementations!

## Citation

If our code or models help your work, please cite our paper:
```BibTeX

```

data/__init__.py

Whitespace-only changes.

data/datasets.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
1+
'''
2+
Build trainining/testing datasets
3+
'''
4+
import os
5+
import json
6+
7+
from torchvision import datasets, transforms
8+
from torchvision.datasets.folder import ImageFolder, default_loader
9+
import torch
10+
11+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12+
from timm.data import create_transform
13+
14+
try:
15+
from timm.data import TimmDatasetTar
16+
except ImportError:
17+
# for higher version of timm
18+
from timm.data import ImageDataset as TimmDatasetTar
19+
20+
class INatDataset(ImageFolder):
21+
def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
22+
category='name', loader=default_loader):
23+
self.transform = transform
24+
self.loader = loader
25+
self.target_transform = target_transform
26+
self.year = year
27+
# assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
28+
path_json = os.path.join(
29+
root, f'{"train" if train else "val"}{year}.json')
30+
with open(path_json) as json_file:
31+
data = json.load(json_file)
32+
33+
with open(os.path.join(root, 'categories.json')) as json_file:
34+
data_catg = json.load(json_file)
35+
36+
path_json_for_targeter = os.path.join(root, f"train{year}.json")
37+
38+
with open(path_json_for_targeter) as json_file:
39+
data_for_targeter = json.load(json_file)
40+
41+
targeter = {}
42+
indexer = 0
43+
for elem in data_for_targeter['annotations']:
44+
king = []
45+
king.append(data_catg[int(elem['category_id'])][category])
46+
if king[0] not in targeter.keys():
47+
targeter[king[0]] = indexer
48+
indexer += 1
49+
self.nb_classes = len(targeter)
50+
51+
self.samples = []
52+
for elem in data['images']:
53+
cut = elem['file_name'].split('/')
54+
target_current = int(cut[2])
55+
path_current = os.path.join(root, cut[0], cut[2], cut[3])
56+
57+
categors = data_catg[target_current]
58+
target_current_true = targeter[categors[category]]
59+
self.samples.append((path_current, target_current_true))
60+
61+
# __getitem__ and __len__ inherited from ImageFolder
62+
63+
64+
def build_dataset(is_train, args):
65+
transform = build_transform(is_train, args)
66+
67+
if args.data_set == 'CIFAR':
68+
dataset = datasets.CIFAR100(
69+
args.data_path, train=is_train, transform=transform)
70+
nb_classes = 100
71+
elif args.data_set == 'IMNET':
72+
prefix = 'train' if is_train else 'val'
73+
data_dir = os.path.join(args.data_path, f'{prefix}.tar')
74+
if os.path.exists(data_dir):
75+
dataset = TimmDatasetTar(data_dir, transform=transform)
76+
else:
77+
root = os.path.join(args.data_path, 'train' if is_train else 'val')
78+
dataset = datasets.ImageFolder(root, transform=transform)
79+
nb_classes = 1000
80+
elif args.data_set == 'IMNETEE':
81+
root = os.path.join(args.data_path, 'train' if is_train else 'val')
82+
dataset = datasets.ImageFolder(root, transform=transform)
83+
nb_classes = 10
84+
elif args.data_set == 'FLOWERS':
85+
root = os.path.join(args.data_path, 'train' if is_train else 'test')
86+
dataset = datasets.ImageFolder(root, transform=transform)
87+
if is_train:
88+
dataset = torch.utils.data.ConcatDataset(
89+
[dataset for _ in range(100)])
90+
nb_classes = 102
91+
elif args.data_set == 'INAT':
92+
dataset = INatDataset(args.data_path, train=is_train, year=2018,
93+
category=args.inat_category, transform=transform)
94+
nb_classes = dataset.nb_classes
95+
elif args.data_set == 'INAT19':
96+
dataset = INatDataset(args.data_path, train=is_train, year=2019,
97+
category=args.inat_category, transform=transform)
98+
nb_classes = dataset.nb_classes
99+
return dataset, nb_classes
100+
101+
102+
def build_transform(is_train, args):
103+
resize_im = args.input_size > 32
104+
if is_train:
105+
# this should always dispatch to transforms_imagenet_train
106+
transform = create_transform(
107+
input_size=args.input_size,
108+
is_training=True,
109+
color_jitter=args.color_jitter,
110+
auto_augment=args.aa,
111+
interpolation=args.train_interpolation,
112+
re_prob=args.reprob,
113+
re_mode=args.remode,
114+
re_count=args.recount,
115+
)
116+
if not resize_im:
117+
# replace RandomResizedCropAndInterpolation with
118+
# RandomCrop
119+
transform.transforms[0] = transforms.RandomCrop(
120+
args.input_size, padding=4)
121+
return transform
122+
123+
t = []
124+
if args.finetune:
125+
t.append(
126+
transforms.Resize((args.input_size, args.input_size),
127+
interpolation=3)
128+
)
129+
else:
130+
if resize_im:
131+
size = int((256 / 224) * args.input_size)
132+
t.append(
133+
# to maintain same ratio w.r.t. 224 images
134+
transforms.Resize(size, interpolation=3),
135+
)
136+
t.append(transforms.CenterCrop(args.input_size))
137+
138+
t.append(transforms.ToTensor())
139+
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
140+
return transforms.Compose(t)
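
A minimal sketch of how `build_dataset` above might be driven outside of `main.py`. The `Namespace` fields and their values (input size, augmentation settings, paths) are illustrative assumptions, not the repository's defaults; `main.py` builds an equivalent namespace via argparse.

```python
from argparse import Namespace

from torch.utils.data import DataLoader

from data.datasets import build_dataset  # assumes the repository root is on PYTHONPATH

# Hypothetical argument namespace covering the fields build_transform/build_dataset read.
args = Namespace(
    data_set='IMNET', data_path='/path/to/imagenet', input_size=224,
    color_jitter=0.4, aa='rand-m9-mstd0.5-inc1', train_interpolation='bicubic',
    reprob=0.25, remode='pixel', recount=1, finetune='',
)

dataset_train, nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
print(nb_classes)  # 1000 for the 'IMNET' setting

loader = DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=8)
```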

data/samplers.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
1+
'''
2+
Build samplers for data loading
3+
'''
4+
import torch
5+
import torch.distributed as dist
6+
import math
7+
8+
9+
class RASampler(torch.utils.data.Sampler):
10+
"""Sampler that restricts data loading to a subset of the dataset for distributed,
11+
with repeated augmentation.
12+
It ensures that different each augmented version of a sample will be visible to a
13+
different process (GPU)
14+
Heavily based on torch.utils.data.DistributedSampler
15+
"""
16+
17+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
18+
if num_replicas is None:
19+
if not dist.is_available():
20+
raise RuntimeError(
21+
"Requires distributed package to be available")
22+
num_replicas = dist.get_world_size()
23+
if rank is None:
24+
if not dist.is_available():
25+
raise RuntimeError(
26+
"Requires distributed package to be available")
27+
rank = dist.get_rank()
28+
self.dataset = dataset
29+
self.num_replicas = num_replicas
30+
self.rank = rank
31+
self.epoch = 0
32+
self.num_samples = int(
33+
math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
34+
self.total_size = self.num_samples * self.num_replicas
35+
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
36+
self.num_selected_samples = int(math.floor(
37+
len(self.dataset) // 256 * 256 / self.num_replicas))
38+
self.shuffle = shuffle
39+
40+
def __iter__(self):
41+
# deterministically shuffle based on epoch
42+
g = torch.Generator()
43+
g.manual_seed(self.epoch)
44+
if self.shuffle:
45+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
46+
else:
47+
indices = list(range(len(self.dataset)))
48+
49+
# add extra samples to make it evenly divisible
50+
indices = [ele for ele in indices for i in range(3)]
51+
indices += indices[:(self.total_size - len(indices))]
52+
assert len(indices) == self.total_size
53+
54+
# subsample
55+
indices = indices[self.rank:self.total_size:self.num_replicas]
56+
assert len(indices) == self.num_samples
57+
58+
return iter(indices[:self.num_selected_samples])
59+
60+
def __len__(self):
61+
return self.num_selected_samples
62+
63+
def set_epoch(self, epoch):
64+
self.epoch = epoch
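
A sketch of how `RASampler` might be plugged into a distributed training loop. The helper function name, batch size, and worker count are illustrative assumptions; `main.py` wires this up from its own arguments, and the process group is assumed to be initialized already (e.g. via `torch.distributed.launch`).

```python
import torch
from torch.utils.data import DataLoader

from data.samplers import RASampler  # assumes the repository root is on PYTHONPATH


def build_train_loader(dataset_train, batch_size=256, num_workers=8):
    # One RASampler per process; each GPU sees a different slice of the repeated indices.
    num_tasks = torch.distributed.get_world_size()
    global_rank = torch.distributed.get_rank()
    sampler = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset_train, sampler=sampler, batch_size=batch_size,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )
    return loader, sampler


# Per-epoch usage: reseed the sampler so each epoch gets a different shuffle.
# loader, sampler = build_train_loader(dataset_train)
# for epoch in range(num_epochs):
#     sampler.set_epoch(epoch)
#     for images, targets in loader:
#         ...
```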
