@@ -11,13 +11,16 @@
 from torch.nn.parameter import Parameter
 from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
+from torch.jit import Final
+
 
 ### Import timm layers
 from timm.layers import (
     DropPath,
     AttentionPoolLatent,
     LayerType,
     LayerScale,
+    use_fused_attn,
 )
 
 # from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
@@ -70,6 +73,7 @@ def forward(self, t: Tensor):
         return freqs
 
 
+
 @register_notrace_module
 class Rope2D(Module):
     def __init__(self, dim, grid_size, use_cls_token=False):
@@ -181,6 +185,8 @@ class SelfAttention(nn.Module):
     r"""
     Implements sequence packed attention and RoPe
     """
+    fused_attn: Final[bool]
+
     def __init__(
             self,
             embed_dim: int,
@@ -201,12 +207,14 @@ def __init__(
 
         self.rope = rope
         self.scale = self.head_dim ** (-0.5)
+        self.fused_attn = use_fused_attn()
 
     def init_tensors(self):
         xavier_uniform_(self.in_proj_weight)
         constant_(self.in_proj_bias, 0.0)
         constant_(self.out_proj.bias, 0.0)
 
+
     def forward(self,
                 x: torch.Tensor,
                 attn_mask: Optional[torch.Tensor] = None,
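
Aside, not part of the diff: the fused_attn: Final[bool] / use_fused_attn() pair added above follows the usual timm idiom, where the gate is resolved once at construction time and the Final annotation lets TorchScript treat it as a constant. A minimal sketch of that idiom, with a hypothetical module name:

# Illustrative sketch only (FusedAttnGateExample is hypothetical, not from this diff):
# the flag is set once in __init__ and Final[bool] marks it constant for TorchScript.
import torch.nn as nn
from torch.jit import Final
from timm.layers import use_fused_attn


class FusedAttnGateExample(nn.Module):
    fused_attn: Final[bool]

    def __init__(self):
        super().__init__()
        # True when timm decides F.scaled_dot_product_attention should be used
        self.fused_attn = use_fused_attn()
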
@@ -226,12 +234,21 @@ def forward(self,
         if self.rope is not None:
             q, k = self.rope(q, k)
 
-        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+        if self.fused_attn:
+            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = attn @ v
+
         attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
 
 
+
+
 class ResidualAttentionBlock(nn.Module):
     def __init__(
             self,
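
For reference, the new else-branch above is plain unfused attention; a standalone sketch (shapes chosen arbitrarily, not taken from this code) comparing it against F.scaled_dot_product_attention:

# Quick numerical check that the manual path matches SDPA.
# Note: the scale= keyword of scaled_dot_product_attention needs PyTorch 2.1+.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))   # (batch, heads, seq, head_dim)
scale = 64 ** -0.5

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=scale)

attn = (q * scale) @ k.transpose(-2, -1)   # same steps as the else-branch above
attn = attn.softmax(dim=-1)
manual = attn @ v

print((fused - manual).abs().max())        # expected to be on the order of 1e-6 in float32
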
@@ -246,10 +263,7 @@ def __init__(
     ):
         super().__init__()
 
-        #if rope:
         self.attn = SelfAttention(d_model, n_head, rope=rope)
-        #else:
-        #    self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
 
         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
@@ -281,10 +295,7 @@ def _call_attn(
             if not attn_mask.dtype == torch.bool:
                 attn_mask = attn_mask.to(q_x.dtype)
 
-        #if isinstance(self.attn, SelfAttention):
         return self.attn(q_x, attn_mask=attn_mask)
-        #else:
-        #    return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
 
     def forward(
             self,
@@ -392,6 +403,7 @@ def __init__(
         self.in_chans = in_chans
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.emb_dim = width
 
         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
@@ -410,6 +422,7 @@ def __init__(
         self.num_features = width
 
         self.num_classes = num_classes
+        self.output_dim = output_dim
 
         self.use_abs_posemb = use_abs_posemb
         self.use_cls_token = use_cls_token
@@ -466,6 +479,7 @@ def __init__(
         else:
             self.attn_pool = None
 
+        self.act_layer_cfg = act_layer
         self.init_tensors()
 
     def init_tensors(self):
@@ -523,8 +537,10 @@ def forward_pool_and_proj(self, x: torch.Tensor):
 
     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
         # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
-        # Ideally pool To discuss with Ross where to split
+        # To discuss with Ross where to split
         x = self.forward_pool_and_proj(x)
+        if self.head_act_layer is not None:
+            x = self.head_act_layer(x)
         return x if pre_logits else self.head(x)
 
     def forward_features(self, x: torch.Tensor, norm: bool = False):
@@ -806,5 +822,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         ls_init_value=0.1,
         use_proj=False,
     )
-    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
-
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
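
For context, a usage sketch of the entrypoint registered in this diff (names only; pretrained weight availability is not established here, and the gigantic variant is very large to instantiate):

# Assumes this PR is merged so the entrypoint below is registered with timm.
import torch
import timm

model = timm.create_model('vit_pe_spatial_gigantic_patch14_448', pretrained=False)
model.eval()

x = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    feats = model.forward_features(x)                  # transformer token features
    out = model.forward_head(feats, pre_logits=True)   # pool -> proj (-> optional head act), no classifier
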