1- """ BEIT : BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
1+ """ BEiT : BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
22
33Model from official source: https://github.com/microsoft/unilm/tree/master/beit
44
6868from .vision_transformer import checkpoint_filter_fn
6969
7070
71- def _cfg (url = '' , ** kwargs ):
72- return {
73- 'url' : url ,
74- 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
75- 'crop_pct' : .9 , 'interpolation' : 'bicubic' , 'fixed_input_size' : True ,
76- 'mean' : (0.5 , 0.5 , 0.5 ), 'std' : (0.5 , 0.5 , 0.5 ),
77- 'first_conv' : 'patch_embed.proj' , 'classifier' : 'head' ,
78- ** kwargs
79- }
80-
81-
default_cfgs = generate_default_cfgs({
    # BEiT v1: official checkpoints, ImageNet-22k pre-training.
    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),
    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
        input_size=(3, 512, 512), crop_pct=1.0,
    ),
    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),

    # BEiT v2: ImageNet-1k pre-training, standard ImageNet normalization.
    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
        crop_pct=0.95,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # EVA giant: checkpoints hosted on the HF hub.
    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
    ),
    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336),
    ),
    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 336, 336),
    ),
    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 560, 560),
    ),
})
145-
146-
14771def gen_relative_position_index (window_size : Tuple [int , int ]) -> torch .Tensor :
14872 num_relative_distance = (2 * window_size [0 ] - 1 ) * (2 * window_size [1 ] - 1 ) + 3
14973 # cls to token & token 2 cls & cls to cls
@@ -416,6 +340,82 @@ def forward(self, x):
416340 return x
417341
418342
343+ def _cfg (url = '' , ** kwargs ):
344+ return {
345+ 'url' : url ,
346+ 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
347+ 'crop_pct' : .9 , 'interpolation' : 'bicubic' , 'fixed_input_size' : True ,
348+ 'mean' : (0.5 , 0.5 , 0.5 ), 'std' : (0.5 , 0.5 , 0.5 ),
349+ 'first_conv' : 'patch_embed.proj' , 'classifier' : 'head' ,
350+ ** kwargs
351+ }
352+
353+
default_cfgs = generate_default_cfgs({
    # BEiT v1: official Microsoft UniLM checkpoints (ImageNet-22k pre-train).
    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),
    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
        input_size=(3, 512, 512), crop_pct=1.0,
    ),
    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),

    # BEiT v2: ImageNet-1k pre-training, standard ImageNet normalization.
    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
        crop_pct=0.95,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # EVA giant: HF-hub hosted checkpoints; larger sizes use squash crop.
    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
    ),
    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash',
    ),
    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash',
    ),
    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash',
    ),
})
417+
418+
419419def _beit_checkpoint_filter_fn (state_dict , model ):
420420 if 'module' in state_dict :
421421 # beit v2 didn't strip module
@@ -425,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):
425425
426426def _create_beit (variant , pretrained = False , ** kwargs ):
427427 if kwargs .get ('features_only' , None ):
428- raise RuntimeError ('features_only not implemented for Beit models.' )
428+ raise RuntimeError ('features_only not implemented for BEiT models.' )
429429
430430 model = build_model_with_cfg (
431431 Beit , variant , pretrained ,
@@ -453,15 +453,6 @@ def beit_base_patch16_384(pretrained=False, **kwargs):
453453 return model
454454
455455
@register_model
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT base (patch16, 224) with ImageNet-22k classifier head."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    # Caller kwargs take precedence over the architecture defaults.
    return _create_beit(
        'beit_base_patch16_224_in22k', pretrained=pretrained, **dict(model_args, **kwargs))
463-
464-
465456@register_model
466457def beit_large_patch16_224 (pretrained = False , ** kwargs ):
467458 model_kwargs = dict (
@@ -489,15 +480,6 @@ def beit_large_patch16_512(pretrained=False, **kwargs):
489480 return model
490481
491482
@register_model
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT large (patch16, 224) with ImageNet-22k classifier head."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    # Caller kwargs take precedence over the architecture defaults.
    return _create_beit(
        'beit_large_patch16_224_in22k', pretrained=pretrained, **dict(model_args, **kwargs))
499-
500-
501483@register_model
502484def beitv2_base_patch16_224 (pretrained = False , ** kwargs ):
503485 model_kwargs = dict (
@@ -507,15 +489,6 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):
507489 return model
508490
509491
@register_model
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT v2 base (patch16, 224) with ImageNet-22k classifier head."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    # Caller kwargs take precedence over the architecture defaults.
    return _create_beit(
        'beitv2_base_patch16_224_in22k', pretrained=pretrained, **dict(model_args, **kwargs))
517-
518-
519492@register_model
520493def beitv2_large_patch16_224 (pretrained = False , ** kwargs ):
521494 model_kwargs = dict (
@@ -525,15 +498,6 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
525498 return model
526499
527500
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT v2 large (patch16, 224) with ImageNet-22k classifier head."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    # Caller kwargs take precedence over the architecture defaults.
    return _create_beit(
        'beitv2_large_patch16_224_in22k', pretrained=pretrained, **dict(model_args, **kwargs))
535-
536-
537501@register_model
538502def eva_giant_patch14_224 (pretrained = False , ** kwargs ):
539503 """ EVA-g model https://arxiv.org/abs/2211.07636 """
0 commit comments