Skip to content

Commit b5a814e

Browse files
committed
add giant model param
1 parent afe4375 commit b5a814e

File tree

1 file changed

+42
-16
lines changed

1 file changed

+42
-16
lines changed

timm/models/beit3.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
1010
@inproceedings{beit3,
1111
title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks},
12-
author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
12+
author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal
13+
and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
1314
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
1415
year={2023}
1516
}
1617
@InProceedings{Wang_2023_CVPR,
17-
author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu},
18+
author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal,
19+
Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu},
1820
title = {Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks},
1921
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
2022
month = {June},
@@ -65,6 +67,7 @@ class PositionalEmbedding(nn.Embedding):
6567
https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119
6668
"""
6769
def forward(self, x: torch.Tensor) -> torch.Tensor:
70+
# being consistent with Fairseq, which starts from 2.
6871
return F.embedding(
6972
torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device),
7073
self.weight,
@@ -108,22 +111,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
108111
v = self.v_proj(x)
109112
q *= self.scaling
110113

114+
## (B, N, C) >> (B, N, num_heads, head_dim) >> (B, num_heads, N, head_dim)
111115
q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
112116
k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
113117
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
118+
119+
## (B, num_heads, N, head_dim) >> (B * num_heads, N, head_dim)
114120
q = q.reshape(B * self.num_heads, N, self.head_dim)
115121
k = k.reshape(B * self.num_heads, N, self.head_dim)
116122
v = v.reshape(B * self.num_heads, N, self.head_dim)
117123

118-
attn_weights = torch.bmm(q, k.transpose(1, 2))
119-
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
120-
attn_weights
121-
)
124+
attn_weights = torch.bmm(q, k.transpose(1, 2)) # (B * num_heads, N, N)
125+
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
122126
attn_probs = self.attn_drop(attn_weights)
127+
attn = torch.bmm(attn_probs, v) # (B * num_heads, N, head_dim)
123128

124-
attn = torch.bmm(attn_probs, v)
125-
attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2)
126-
attn = attn.reshape(B, N, C)
129+
## (B * num_heads N, head_dim) >> (B, N, num_heads * head_dim) == (B, N, C)
130+
attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2).reshape(B, N, C)
127131
attn = self.inner_attn_ln(attn)
128132
attn = self.out_proj(attn)
129133
return attn
@@ -409,26 +413,28 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
409413

410414

411415
default_cfgs = generate_default_cfgs({
412-
'beit3_base_patch16_224.in1k': _cfg(
416+
'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
413417
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
414418
# hf_hub_id='timm/',
415419
),
416-
'beit3_base_patch16_224.indomain_in1k': _cfg(
420+
'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg(
417421
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
418422
# hf_hub_id='timm/',
419423
),
420-
'beit3_large_patch16_224.in1k': _cfg(
424+
'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
421425
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
422426
# hf_hub_id='timm/',
423427
),
424-
'beit3_large_patch16_224.indomain_in1k': _cfg(
428+
'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg(
425429
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
426430
# hf_hub_id='timm/',
427431
),
432+
'beit3_giant_patch14_224.untrained': _cfg(url=''),
433+
'beit3_giant_patch14_336.untrained': _cfg(url='', input_size=(3, 336, 336)),
428434
})
429435

430436

431-
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]:
437+
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
432438
if 'model' in state_dict:
433439
state_dict = state_dict['model']
434440

@@ -459,11 +465,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> D
459465
k = k.replace('A.', '')
460466

461467
out_dict[k] = v
462-
468+
463469
return out_dict
464470

465471

466-
def _create_beit3(variant: str, pretrained: bool, **kwargs: Any) -> BEiT3:
472+
def _create_beit3(variant: str, pretrained: bool = False, **kwargs: Any) -> BEiT3:
467473
out_indices = kwargs.pop('out_indices', 3)
468474
model = build_model_with_cfg(
469475
BEiT3, variant, pretrained,
@@ -488,3 +494,23 @@ def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
488494
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4)
489495
model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
490496
return model
497+
498+
499+
@register_model
500+
def beit3_giant_patch14_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
501+
## FFN inner hidden size = embed_dim * mlp_ratio
502+
## 6144 = int(1408 * 4.3637)
503+
model_args = dict(
504+
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637)
505+
model = _create_beit3('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
506+
return model
507+
508+
509+
@register_model
510+
def beit3_giant_patch14_336(pretrained: bool = False, **kwargs: Any) -> BEiT3:
511+
## FFN inner hidden size = embed_dim * mlp_ratio
512+
## 6144 = int(1408 * 4.3637)
513+
model_args = dict(
514+
img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637)
515+
model = _create_beit3('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
516+
return model

0 commit comments

Comments
 (0)