 import math
 import os
 from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+
 try:
     from timm.models.layers import drop_path, to_2tuple, trunc_normal_
 except:
     from timm.layers import drop_path, to_2tuple, trunc_normal_
-
+
+from .rope import VisionRotaryEmbeddingFast
 from .transformer import PatchDropout
-from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
+
 
 if os.getenv('ENV_TYPE') == 'deepspeed':
     try:
         from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
     except:
         from torch.utils.checkpoint import checkpoint
 else:
     from torch.utils.checkpoint import checkpoint
 
 try:
-    import xformers
     import xformers.ops as xops
     XFORMERS_IS_AVAILBLE = True
 except:
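
For reviewers, a minimal sketch (not part of this file) of how an availability flag like XFORMERS_IS_AVAILBLE is typically consumed, falling back to plain PyTorch attention when xformers is missing. The helper name and tensor layout are illustrative assumptions; the snippet reuses the module-level imports above and PyTorch 2.x's F.scaled_dot_product_attention.

# Hypothetical helper: q, k, v are (batch, seq, heads, head_dim),
# the layout xformers.ops.memory_efficient_attention expects.
def attention_with_fallback(q, k, v, drop_p=0.0):
    if XFORMERS_IS_AVAILBLE:
        return xops.memory_efficient_attention(q, k, v, p=drop_p)
    # Plain PyTorch path: scaled_dot_product_attention wants (batch, heads, seq, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
    return out.transpose(1, 2)
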
@@ -39,19 +42,19 @@ def __init__(self, drop_prob=None):
 
     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)
-
+
     def extra_repr(self) -> str:
         return 'p={}'.format(self.drop_prob)
 
 
 class Mlp(nn.Module):
     def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
         drop=0.,
         subln=False,
 
@@ -71,15 +74,15 @@ def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
         # x = self.drop(x)
-        # comment this for the original BERT implementation
+        # comment this for the original BERT implementation
         x = self.ffn_ln(x)
 
         x = self.fc2(x)
         x = self.drop(x)
         return x
 
 class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                  norm_layer=nn.LayerNorm, subln=False):
         super().__init__()
         out_features = out_features or in_features
@@ -91,7 +94,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         self.act = act_layer()
         self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
         self.w3 = nn.Linear(hidden_features, out_features)
-
+
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
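
The hunk above only shows part of the SwiGLU constructor (w3, the optional sub-LayerNorm, and dropout). For readers unfamiliar with the block, here is a self-contained hedged sketch of the gated feed-forward idea it implements; the w1/w2 names mirror w3 but are an assumption about code not shown in this diff, and the snippet reuses the torch imports at the top of the file.

class SwiGLUSketch(nn.Module):
    """Illustrative gated FFN: out = W3( LN( SiLU(W1 x) * (W2 x) ) )."""
    def __init__(self, dim, hidden, subln=False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)   # gate branch
        self.w2 = nn.Linear(dim, hidden)   # value branch
        self.ffn_ln = nn.LayerNorm(hidden) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden, dim)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(self.ffn_ln(gated))
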
@@ -172,20 +175,20 @@ def __init__(
 
     def forward(self, x, rel_pos_bias=None, attn_mask=None):
         B, N, C = x.shape
-        if self.subln:
+        if self.subln:
             q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
             k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
             v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
 
             q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)    # B, num_heads, N, C
-            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-        else:
+            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+        else:
 
             qkv_bias = None
             if self.q_bias is not None:
                 qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
-
+
             qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)   # 3, B, num_heads, N, C
             q, k, v = qkv[0], qkv[1], qkv[2]
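
A quick, self-contained shape check of the fused-qkv path above (sizes are illustrative only). It also shows why the concatenated bias leaves k without a bias term, matching the separate-projection branch where k_proj is applied with bias=None.

B, N, C, num_heads = 2, 5, 8, 2
x = torch.randn(B, N, C)
qkv_weight = torch.randn(3 * C, C)
q_bias, v_bias = torch.randn(C), torch.randn(C)
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias), v_bias))   # k gets a zero bias
qkv = F.linear(x, qkv_weight, qkv_bias).reshape(B, N, 3, num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]   # each is (B, num_heads, N, C // num_heads)
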
@@ -232,7 +235,7 @@ def forward(self, x, rel_pos_bias=None, attn_mask=None):
             if attn_mask is not None:
                 attn_mask = attn_mask.bool()
                 attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
-
+
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
 
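
A tiny illustration of the masking above (values are made up): keys whose mask entry is False are filled with -inf and end up with exactly zero attention weight after the softmax.

scores = torch.zeros(1, 1, 1, 4)                    # (B, heads, queries, keys)
mask = torch.tensor([[True, True, False, False]])   # (B, keys); False marks padding
scores = scores.masked_fill(~mask[:, None, None, :], float("-inf"))
print(scores.softmax(dim=-1))                       # weights -> [0.5, 0.5, 0.0, 0.0]
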
@@ -262,15 +265,15 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
 
         if naiveswiglu:
             self.mlp = SwiGLU(
-                in_features=dim,
-                hidden_features=mlp_hidden_dim,
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
                 subln=subln,
                 norm_layer=norm_layer,
             )
         else:
             self.mlp = Mlp(
-                in_features=dim,
-                hidden_features=mlp_hidden_dim,
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
                 act_layer=act_layer,
                 subln=subln,
                 drop=drop
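
For reviewers comparing the two branches: both FFN variants take the same in/hidden sizes, so flipping naiveswiglu changes the block's math but not its interface. The values below are illustrative; mlp_hidden_dim = int(dim * mlp_ratio) is an assumption based on the usual ViT convention, since that line sits outside this hunk.

dim, mlp_ratio = 768, 4.0
mlp_hidden_dim = int(dim * mlp_ratio)   # 3072 with these example values
ffn_a = SwiGLU(in_features=dim, hidden_features=mlp_hidden_dim, subln=True, norm_layer=nn.LayerNorm)
ffn_b = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU, subln=True, drop=0.)
tokens = torch.randn(2, 197, dim)       # (batch, 1 cls + 14*14 patches, dim)
assert ffn_a(tokens).shape == ffn_b(tokens).shape == tokens.shape
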
@@ -407,7 +410,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
                 ft_seq_len=hw_seq_len if intp_freq else None,
                 # patch_dropout=patch_dropout
             )
-        else:
+        else:
             self.rope = None
 
         self.naiveswiglu = naiveswiglu
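
The self.rope built above is consumed inside the attention blocks, and that code is not part of this diff, so the following is only a hedged sketch of the usual EVA-style pattern: rotate q/k for the patch tokens and leave the leading class token untouched. The (B, num_heads, tokens, head_dim) layout and the helper name are assumptions.

def apply_rope(rope, q, k):
    # Hypothetical helper: rope is assumed callable on (B, heads, patches, head_dim).
    q = torch.cat((q[:, :, :1, :], rope(q[:, :, 1:, :])), dim=2)   # keep the CLS token unrotated
    k = torch.cat((k[:, :, :1, :], rope(k[:, :, 1:, :])), dim=2)
    return q, k
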
@@ -469,7 +472,7 @@ def _init_weights(self, m):
 
     def get_num_layers(self):
         return len(self.blocks)
-
+
     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         assert unlocked_groups == 0, 'partial locking not currently supported for this model'
         for param in self.parameters():
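
The loop body is cut off by the diff context; assuming it sets requires_grad = False as in other open_clip-style towers, a hypothetical caller would freeze the visual tower like this:

# `model` stands for an instance of this vision tower (hypothetical usage).
model.lock()   # only full locking is supported, per the assert above
assert not any(p.requires_grad for p in model.parameters())
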
@@ -491,7 +494,7 @@ def reset_classifier(self, num_classes, global_pool=''):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
-
+
         x = self.patch_embed(x)
         batch_size, seq_len, _ = x.size()
 