Skip to content

Commit afe4375

Browse files
committed
update BEiT3
1 parent 0085149 commit afe4375

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,13 +56,13 @@
5656
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
5757
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
5858
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
59-
'davit', 'rdnet', 'convnext', 'pit'
59+
'davit', 'rdnet', 'convnext', 'pit', 'beit3',
6060
]
6161

6262
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
6363
NON_STD_FILTERS = [
6464
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
65-
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
65+
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'beit3*',
6666
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
6767
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
6868
]

timm/models/beit3.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def __init__(
8686
dim: int,
8787
num_heads: int,
8888
drop_rate: float = 0.,
89-
norm_layer: LayerType = partial(LayerNorm, eps=1e-5)
89+
norm_layer: LayerType = partial(LayerNorm, eps=1e-5),
9090
):
9191
super().__init__()
9292
self.num_heads = num_heads
@@ -122,7 +122,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
122122
attn_probs = self.attn_drop(attn_weights)
123123

124124
attn = torch.bmm(attn_probs, v)
125-
attn = attn.transpose(0, 1).reshape(N, B, C).transpose(0, 1)
125+
attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2)
126+
attn = attn.reshape(B, N, C)
126127
attn = self.inner_attn_ln(attn)
127128
attn = self.out_proj(attn)
128129
return attn
@@ -403,7 +404,7 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
403404
'paper_ids': 'arXiv:2208.10442',
404405
'paper_name': 'Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks',
405406
'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3',
406-
**kwargs
407+
**kwargs,
407408
}
408409

409410

@@ -427,10 +428,7 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
427428
})
428429

429430

430-
def checkpoint_filter_fn(
431-
state_dict: Dict[str, torch.Tensor],
432-
model: BEiT3,
433-
) -> Dict[str, torch.Tensor]:
431+
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]:
434432
if 'model' in state_dict:
435433
state_dict = state_dict['model']
436434

0 commit comments

Comments
 (0)