
Commit e6bbf9f

add default config
1 parent 8eafe2c

1 file changed: +42 -5 lines


timm/models/pe.py

Lines changed: 42 additions & 5 deletions
@@ -18,28 +18,30 @@
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
     get_act_layer, get_norm_layer, LayerType, LayerScale
 #from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations


+__all__ = ['PE']
+
+
+####### PE's Rope ########
+
 def exists(val):
     return val is not None

-
 def default(val, d):
     return val if exists(val) else d

-
 def rotate_half(x):
     x = rearrange(x, "... (d r) -> ... d r", r=2)
     x1, x2 = x.unbind(dim=-1)
     x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, "... d r -> ... (d r)")

-
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
     dtype = t.dtype
@@ -330,6 +332,7 @@ def __call__(self, q, k):

         return q, k

+####### PE's Modules ########

 class AttentionPooling(nn.Module):
     def __init__(
@@ -801,6 +804,41 @@ def checkpoint_filter_fn(
     state_dict = {k.replace("visual.", ""): v for k, v in state_dict.items() if "visual" in k}
     return state_dict

+
+default_cfgs = generate_default_cfgs({
+    'pe_core_b16_224': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 224, 224)),
+    'pe_core_l14_336': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 336, 336)),
+    'pe_core_G14_448': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 448, 448)),
+    'pe_lang_l14_448': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 448, 448)),
+    'pe_lang_G14_448': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 448, 448)),
+    'pe_spatial_G14_448': _cfg(
+        hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0,
+        input_size=(3, 448, 448)),
+})
+
+
 def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
     out_indices = kwargs.pop('out_indices', 3)

@@ -814,7 +852,6 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
         **kwargs,
     )

-
 @register_model
 def pe_core_b16_224(pretrained=False, **kwargs):
     model_args = dict(
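
The new `default_cfgs` block registers a pretrained config per PE variant via `generate_default_cfgs`, so each model carries its Inception mean/std, its input size, and `num_classes=0` (an embedding/feature checkpoint with no classifier head). As a rough usage sketch, assuming a timm build that includes this branch of pe.py; the `hf_hub_id='timm/'` weight locations are placeholders in this commit, so `pretrained=False` is used here:

# Minimal sketch, assuming a timm install that contains this pe.py.
# The model name comes from the registrations in this file; hub weights
# may not be published yet, hence pretrained=False.
import timm
from timm.data import resolve_model_data_config, create_transform

model = timm.create_model('pe_core_b16_224', pretrained=False)

# The default cfg added in this commit resolves into the model's data config:
# IMAGENET_INCEPTION_MEAN/STD and input_size=(3, 224, 224).
data_cfg = resolve_model_data_config(model)
print(data_cfg['mean'], data_cfg['std'], data_cfg['input_size'])

# Matching eval-time preprocessing built from the same config.
transform = create_transform(**data_cfg, is_training=False)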
