
Commit 98047ef

Add EVA FT results, hopefully fix BEiT test failures
1 parent 3cc4d7a commit 98047ef

File tree

- README.md
- benchmark.py
- tests/test_models.py
- timm/models/beit.py
- timm/models/pretrained.py

5 files changed: +96 -120 lines


README.md

Lines changed: 10 additions & 1 deletion

@@ -22,7 +22,16 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 ## What's New

 # Dec 6, 2022
-* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain from https://github.com/baaivision/EVA
+* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
+  * original source: https://github.com/baaivision/EVA
+  * paper: https://arxiv.org/abs/2211.07636
+
+| model                                    |   top1 |   param_count |   gmac |   macts | hub                                      |
+|:-----------------------------------------|-------:|--------------:|-------:|--------:|:-----------------------------------------|
+| eva_giant_patch14_560.m30m_ft_in22k_in1k |   89.8 |        1014.4 | 1906.8 |  2577.2 | [link](https://huggingface.co/BAAI/EVA)  |
+| eva_giant_patch14_336.m30m_ft_in22k_in1k |   89.6 |        1013   |  620.6 |   550.7 | [link](https://huggingface.co/BAAI/EVA)  |
+| eva_giant_patch14_336.clip_ft_in1k       |   89.4 |        1013   |  620.6 |   550.7 | [link](https://huggingface.co/BAAI/EVA)  |
+| eva_giant_patch14_224.clip_ft_in1k       |   89.1 |        1012.6 |  267.2 |   192.6 | [link](https://huggingface.co/BAAI/EVA)  |

 # Dec 5, 2022
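The new weights should be reachable through the standard `timm` factory. A minimal sketch, assuming a `timm` checkout that includes this commit, hub access, and tag-style model names as listed in the table above:

```python
import timm
import torch

# Pretrained weights resolve to the BAAI/EVA hub repo listed in the table.
model = timm.create_model('eva_giant_patch14_224.clip_ft_in1k', pretrained=True)
model = model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) -- ImageNet-1k head
```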

benchmark.py

Lines changed: 4 additions & 2 deletions

@@ -80,9 +80,11 @@
 parser.add_argument('--results-format', default='csv', type=str,
                     help='Format for results file one of (csv, json) (default: csv).')
 parser.add_argument('--num-warm-iter', default=10, type=int,
-                    metavar='N', help='Number of warmup iterations (default: 10)')
+                    help='Number of warmup iterations (default: 10)')
 parser.add_argument('--num-bench-iter', default=40, type=int,
-                    metavar='N', help='Number of benchmark iterations (default: 40)')
+                    help='Number of benchmark iterations (default: 40)')
+parser.add_argument('--device', default='cuda', type=str,
+                    help="device to run benchmark on")

 # common inference / train args
 parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
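The new `--device` flag lets the benchmark target something other than CUDA. A hypothetical invocation, using only flag names visible in the diff above and the parser's `resnet50` default:

```
python benchmark.py --model resnet50 --device cpu --num-warm-iter 5 --num-bench-iter 20
```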

tests/test_models.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*',
+    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

@@ -39,7 +39,7 @@
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
         'swin*giant*']
-    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
+    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
 else:
     EXCLUDE_FILTERS = []
     NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
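Assuming these filters are matched with `fnmatch`-style wildcards (as `timm` does elsewhere for model-name filtering), `'eva_*'` routes the new EVA models through the non-standard test path, while `'eva_giant*'` additionally keeps the giant variants out of the heavier tests on GitHub Actions runners:

```python
from fnmatch import fnmatch

names = ['eva_giant_patch14_224', 'beit_base_patch16_224']
# 'eva_*' catches the new models; 'beit*' already covered the BEiT family.
print([n for n in names if fnmatch(n, 'eva_*')])       # ['eva_giant_patch14_224']
print([n for n in names if fnmatch(n, 'eva_giant*')])  # ['eva_giant_patch14_224']
```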

timm/models/beit.py

Lines changed: 78 additions & 114 deletions

@@ -1,4 +1,4 @@
-""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)

 Model from official source: https://github.com/microsoft/unilm/tree/master/beit

@@ -68,82 +68,6 @@
 from .vision_transformer import checkpoint_filter_fn


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = generate_default_cfgs({
-    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
-    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
-    ),
-    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
-        num_classes=21841,
-    ),
-    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
-    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
-    ),
-    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
-        input_size=(3, 512, 512), crop_pct=1.0,
-    ),
-    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
-        num_classes=21841,
-    ),
-
-    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
-        crop_pct=0.95,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-
-    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-    ),
-    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 336, 336)),
-    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
-        input_size=(3, 336, 336)),
-    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
-        input_size=(3, 560, 560)),
-})
-
-
 def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
     # cls to token & token 2 cls & cls to cls

@@ -416,6 +340,82 @@ def forward(self, x):
         return x


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
+    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0,
+    ),
+    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
+        num_classes=21841,
+    ),
+    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
+    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0,
+    ),
+    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
+        input_size=(3, 512, 512), crop_pct=1.0,
+    ),
+    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
+        num_classes=21841,
+    ),
+
+    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
+        crop_pct=0.95,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+
+    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
+    ),
+    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
+})
+
+
 def _beit_checkpoint_filter_fn(state_dict, model):
     if 'module' in state_dict:
         # beit v2 didn't strip module

@@ -425,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):

 def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Beit models.')
+        raise RuntimeError('features_only not implemented for BEiT models.')

     model = build_model_with_cfg(
         Beit, variant, pretrained,

@@ -453,15 +453,6 @@ def beit_base_patch16_384(pretrained=False, **kwargs):
     return model


-@register_model
-def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
-    model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def beit_large_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(

@@ -489,15 +480,6 @@ def beit_large_patch16_512(pretrained=False, **kwargs):
     return model


-@register_model
-def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def beitv2_base_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(

@@ -507,15 +489,6 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):
     return model


-@register_model
-def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def beitv2_large_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(

@@ -525,15 +498,6 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
     return model


-@register_model
-def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def eva_giant_patch14_224(pretrained=False, **kwargs):
     """ EVA-g model https://arxiv.org/abs/2211.07636 """

timm/models/pretrained.py

Lines changed: 2 additions & 1 deletion

@@ -59,10 +59,11 @@ def to_dict(self, remove_source=False, remove_null=True):

 def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
     filtered_cfg = {}
+    keep_none = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
     for k, v in cfg.items():
         if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
             continue
-        if remove_null and v is None:
+        if remove_null and v is None and k not in keep_none:
             continue
         filtered_cfg[k] = v
     return filtered_cfg
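With the `keep_none` change, a `None`-valued `pool_size`, `first_conv`, or `classifier` now survives filtering instead of silently disappearing from a filtered config. A behavior sketch, adapted from the patched function above (the duplicated `'hf_hub_id'` in the source set is dropped here):

```python
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    filtered_cfg = {}
    keep_none = {'pool_size', 'first_conv', 'classifier'}  # always keep, even if None
    for k, v in cfg.items():
        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}:
            continue
        if remove_null and v is None and k not in keep_none:
            continue
        filtered_cfg[k] = v
    return filtered_cfg

print(filter_pretrained_cfg({'pool_size': None, 'crop_pct': None, 'url': ''}))
# {'pool_size': None, 'url': ''} -- crop_pct is dropped, pool_size is kept
```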
