Skip to content

Commit 9a38416

Browse files
authored
Merge pull request #323 from rwightman/imagenet21k_datasets_more
BiT (Big Transfer) ResNetV2 models, Official ViT Hybrid R50 weights, VIT IN21K weights updated w/ repr layer, ImageNet21k and dataset / parser refactor
2 parents f8463b8 + 745bc5f commit 9a38416

31 files changed

+23859
-307
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22

33
## What's New
44

5+
### Jan 25, 2021
6+
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
7+
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
8+
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
9+
* NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
10+
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
11+
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
12+
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
13+
* Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
14+
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
15+
* Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
16+
* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling
17+
518
### Jan 3, 2021
619
* Add SE-ResNet-152D weights
720
* 256x256 val, 0.94 crop top-1 - 83.75
@@ -130,7 +143,9 @@ All model architecture families include variants with pretrained weights. The ar
130143

131144
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
132145

146+
* Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
133147
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
148+
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
134149
* DenseNet - https://arxiv.org/abs/1608.06993
135150
* DLA - https://arxiv.org/abs/1707.06484
136151
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
@@ -242,6 +257,10 @@ One of the greatest assets of PyTorch is the community and their contributions.
242257
* Albumentations - https://github.com/albumentations-team/albumentations
243258
* Kornia - https://github.com/kornia/kornia
244259

260+
### Knowledge Distillation
261+
* RepDistiller - https://github.com/HobbitLong/RepDistiller
262+
* torchdistill - https://github.com/yoshitomo-matsubara/torchdistill
263+
245264
### Metric Learning
246265
* PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning
247266

docs/models.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ Most included models have pretrained weights. The weights are either:
1010

1111
The validation results for the pretrained weights can be found [here](results.md)
1212

13+
## Big Transfer ResNetV2 (BiT) [[resnetv2.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnetv2.py)]
14+
* Paper: `Big Transfer (BiT): General Visual Representation Learning` - https://arxiv.org/abs/1912.11370
15+
* Reference code: https://github.com/google-research/big_transfer
16+
1317
## Cross-Stage Partial Networks [[cspnet.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cspnet.py)]
1418
* Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
1519
* Reference impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414

1515
from timm.models import create_model, apply_test_time_pool
16-
from timm.data import Dataset, create_loader, resolve_data_config
16+
from timm.data import ImageDataset, create_loader, resolve_data_config
1717
from timm.utils import AverageMeter, setup_default_logging
1818

1919
torch.backends.cudnn.benchmark = True
@@ -83,7 +83,7 @@ def main():
8383
model = model.cuda()
8484

8585
loader = create_loader(
86-
Dataset(args.data),
86+
ImageDataset(args.data),
8787
input_size=config['input_size'],
8888
batch_size=args.batch_size,
8989
use_prefetcher=True,

0 commit comments

Comments
 (0)