1818 trunc_normal_ , lecun_normal_ , resample_patch_embed , resample_abs_pos_embed , use_fused_attn , \
1919 get_act_layer , get_norm_layer , LayerType , LayerScale
2020#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
21+ from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
2122
2223from ._builder import build_model_with_cfg
2324from ._features import feature_take_indices
24- from ._manipulate import named_apply , checkpoint_seq , adapt_input_conv
2525from ._registry import generate_default_cfgs , register_model , register_model_deprecations
2626
2727
28+ __all__ = ['PE' ]
29+
30+
31+ ####### PE's Rope ########
32+
2833def exists (val ):
2934 return val is not None
3035
31-
3236def default (val , d ):
3337 return val if exists (val ) else d
3438
35-
3639def rotate_half (x ):
3740 x = rearrange (x , "... (d r) -> ... d r" , r = 2 )
3841 x1 , x2 = x .unbind (dim = - 1 )
3942 x = torch .stack ((- x2 , x1 ), dim = - 1 )
4043 return rearrange (x , "... d r -> ... (d r)" )
4144
42-
4345@autocast ("cuda" , enabled = False )
4446def apply_rotary_emb (freqs , t , start_index = 0 , scale = 1.0 , seq_dim = - 2 ):
4547 dtype = t .dtype
@@ -330,6 +332,7 @@ def __call__(self, q, k):
330332
331333 return q , k
332334
335+ ####### PE's Modules ########
333336
334337class AttentionPooling (nn .Module ):
335338 def __init__ (
@@ -801,6 +804,41 @@ def checkpoint_filter_fn(
801804 state_dict = {k .replace ("visual." , "" ): v for k , v in state_dict .items () if "visual" in k }
802805 return state_dict
803806
807+
808+ default_cfgs = generate_default_cfgs ({
809+ 'pe_core_b16_224' : _cfg (
810+ hf_hub_id = 'timm/' ,
811+ license = 'apache-2.0' ,
812+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
813+ input_size = (3 , 224 , 224 )),
814+ 'pe_core_l14_336' : _cfg (
815+ hf_hub_id = 'timm/' ,
816+ license = 'apache-2.0' ,
817+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
818+ input_size = (3 , 336 , 336 )),
819+ 'pe_core_G14_448' : _cfg (
820+ hf_hub_id = 'timm/' ,
821+ license = 'apache-2.0' ,
822+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
823+ input_size = (3 , 448 , 448 )),
824+ 'pe_lang_l14_448' : _cfg (
825+ hf_hub_id = 'timm/' ,
826+ license = 'apache-2.0' ,
827+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
828+ input_size = (3 , 448 , 448 )),
829+ 'pe_lang_G14_448' : _cfg (
830+ hf_hub_id = 'timm/' ,
831+ license = 'apache-2.0' ,
832+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
833+ input_size = (3 , 448 , 448 )),
834+ 'pe_spatial_G14_448' : _cfg (
835+ hf_hub_id = 'timm/' ,
836+ license = 'apache-2.0' ,
837+ mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD , num_classes = 0 ,
838+ input_size = (3 , 448 , 448 )),
839+ })
840+
841+
804842def _create_pe (variant : str , pretrained : bool = False , ** kwargs ) -> PE :
805843 out_indices = kwargs .pop ('out_indices' , 3 )
806844
@@ -814,7 +852,6 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
814852 ** kwargs ,
815853 )
816854
817-
818855@register_model
819856def pe_core_b16_224 (pretrained = False , ** kwargs ):
820857 model_args = dict (
0 commit comments