@@ -41,11 +41,11 @@ class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
-        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
+        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
         theta=10000,
         max_freq=10,
         num_freqs=1,
-        learned_freq=False,
+        learned_freq=False,
         theta_rescale_factor=1.0,
     ):
         super().__init__()
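
The constructor body is outside this hunk, so for orientation only: the conventional `freqs_for="lang"` rotary schedule (an assumption about the standard RoPE formula, not a quote from this module) assigns one inverse power of `theta` to each rotated channel pair.

```python
import torch

# Hypothetical standalone sketch of the standard "lang" RoPE schedule,
# freq_i = theta ** (-2i / dim); values are illustrative, not from this file.
dim, theta = 32, 10000
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
print(freqs.shape)  # torch.Size([16]) -> one frequency per rotated channel pair
```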
@@ -73,7 +73,6 @@ def forward(self, t: Tensor):
         return freqs
 
 
-
 @register_notrace_module
 class Rope2D(Module):
     def __init__(self, dim, grid_size, use_cls_token=False):
@@ -83,10 +82,10 @@ def __init__(self, dim, grid_size, use_cls_token=False):
         self.grid_size = grid_size
         self.rope = RotaryEmbedding(self.dim // 2)
         self.init_tensors()
-
+
     def init_tensors(self):
         self.update_grid(self.grid_size[0], self.grid_size[1])
-
+
     def update_grid(self, grid_h, grid_w):
         if self.use_cls_token:
             # +1 to leave space for the cls token to be (0, 0)
@@ -100,22 +99,22 @@ def update_grid(self, grid_h, grid_w):
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
 
         if self.use_cls_token:
-            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)
+            freq = torch.cat([torch.zeros(1, freq.shape[-1]), freq], dim=0)
         self.register_buffer('freq', freq[None, ...], persistent=False)
 
     def rotate_half(self, x):
-        shape = x.shape
+        shape = x.shape
         x = x.view(shape[:-1] + (-1, 2))
         x1, x2 = x[..., 0], x[..., 1]
         x = torch.stack((-x2, x1), dim=-1)
         return x.view(shape[:-1] + (-1,))
-
+
     def apply_rotary_emb(self, freqs, t):
         start_index = 0
         scale = 1.0
         seq_dim = -2
         dtype = t.dtype
-
+
         # if len(t.shape) == 3:
         # seq_len = t.shape[seq_dim]
         # freqs = freqs[-seq_len:]
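
The change above moves the all-zero frequency row (an identity rotation) to index 0 so it lines up with the class token, which `forward_features` prepends to the patch tokens. As a standalone illustration of the pairing scheme (toy values, not this module's API), `rotate_half` views the last dimension as `(x1, x2)` pairs and maps each to `(-x2, x1)`:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Same pairing as the method above: view as (..., d/2, 2), rotate each pair by 90 degrees.
    shape = x.shape
    x = x.view(shape[:-1] + (-1, 2))
    x1, x2 = x[..., 0], x[..., 1]
    return torch.stack((-x2, x1), dim=-1).view(shape[:-1] + (-1,))


print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))  # tensor([-2., 1., -4., 3.])
```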
@@ -185,6 +184,7 @@ class SelfAttention(nn.Module):
     r"""
     Implements sequence packed attention and RoPe
     """
+
     fused_attn: Final[bool]
 
     def __init__(
@@ -214,11 +214,11 @@ def init_tensors(self):
         constant_(self.in_proj_bias, 0.0)
         constant_(self.out_proj.bias, 0.0)
 
-
-    def forward(self,
-        x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-    ):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ):
         batch, seq, embed_dim = x.shape
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
 
@@ -235,7 +235,9 @@ def forward(self,
             q, k = self.rope(q, k)
 
         if self.fused_attn:
-            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
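
For context, the wrapped `F.scaled_dot_product_attention` call and the manual path in the `else` branch (continued below the hunk) produce the same attention output up to floating-point error; a minimal sketch with assumed toy shapes, noting the `scale=` keyword needs PyTorch 2.1+:

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
scale = 16 ** -0.5

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=scale)

attn = (q * scale) @ k.transpose(-2, -1)  # scaled dot products, as in the else branch
manual = attn.softmax(dim=-1) @ v         # softmax over keys, then weighted sum of values

print(torch.allclose(fused, manual, atol=1e-5))  # True
```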
@@ -247,8 +249,6 @@ def forward(self,
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
 
 
-
-
 class ResidualAttentionBlock(nn.Module):
     def __init__(
         self,
@@ -285,11 +285,7 @@ def __init__(
             )
         )
 
-    def _call_attn(
-        self,
-        q_x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None
-    ):
+    def _call_attn(self, q_x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         if attn_mask is not None:
             # Leave boolean masks as is
             if not attn_mask.dtype == torch.bool:
@@ -300,7 +296,7 @@ def _call_attn(
     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ):
         x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
@@ -354,18 +350,14 @@ def truncate(self, layer_idx: int):
     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        # layer_idx=-1, #: int = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported as in orig pe
+        attn_mask: Optional[torch.Tensor] = None,
     ):
-        #stop_idx = (self.layers + layer_idx) % self.layers
         for i, r in enumerate(self.resblocks):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-            # if i == stop_idx:
-            # break
         return x
 
 
@@ -389,11 +381,11 @@ def __init__(
         use_cls_token: bool = False,
         use_proj: bool = True,
         output_dim: Optional[int] = 1280,
-        num_classes: int = 0,
+        num_classes: int = 0,
         attn_pooler_heads: int = 8,
         use_attn_pool: bool = True,
         in_chans: int = 3,
-        drop_rate: float = 0.,  # Expected to be here, TODO add a final drop layer once head finalized
+        drop_rate: float = 0.0,  # Expected to be here, TODO add a final drop layer once head finalized
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -404,7 +396,7 @@ def __init__(
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.emb_dim = width
-
+
         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
         # forward_features: x -> Transfomer(x)
@@ -414,10 +406,10 @@ def __init__(
         if self.use_proj:
             self.proj_dim = output_dim
             self.head_hidden_size = self.proj_dim
-            self.num_features = width # self.proj_dim
+            self.num_features = width  # self.proj_dim
         else:
-            self.proj_dim = 0
-            assert output_dim == width
+            self.proj_dim = 0
+            assert output_dim == width
             self.head_hidden_size = width
             self.num_features = width
 
@@ -445,7 +437,7 @@ def __init__(
             Rope2D(
                 dim=width // heads,
                 use_cls_token=self.use_cls_token,
-                grid_size=(img_size // patch_size, img_size // patch_size),
+                grid_size=(img_size // patch_size, img_size // patch_size),
             )
             if self.use_rope2d
             else None
@@ -466,8 +458,7 @@ def __init__(
             rope=self.rope,
         )
 
-        self.feature_info = [
-            dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
+        self.feature_info = [dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
 
         if use_attn_pool:
             self.attn_pool = AttentionPooling(
@@ -479,7 +470,7 @@ def __init__(
479470 else :
480471 self .attn_pool = None
481472
482- self .head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head)
473+ self .head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head)
483474 self .init_tensors ()
484475
485476 def init_tensors (self ):
@@ -511,11 +502,11 @@ def init_submodule_tensors(module):
         # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
         if self.use_proj:
             self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
-        else: # no projection (eg PE-lang and PE-spatial)
+        else:  # no projection (eg PE-lang and PE-spatial)
             self.proj = None
 
         if self.num_classes > 0:
-            self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. input dim = self.width (pooled)
+            self.head = nn.Linear(self.head_hidden_size, self.num_classes)  # no proj. input dim = self.width (pooled)
         else:
             self.head = nn.Identity()
 
@@ -536,8 +527,8 @@ def forward_pool_and_proj(self, x: torch.Tensor):
         return x
 
     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
-        # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
-        # To discuss with Ross where to split
+        # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
+        # To discuss with Ross where to split
         x = self.forward_pool_and_proj(x)
         if self.head_act_layer is not None:
             x = self.head_act_layer(x)
@@ -554,7 +545,7 @@ def forward_features(self, x: torch.Tensor, norm: bool = False):
                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                 dim=1,
             )
-
+
         if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]
 
@@ -575,22 +566,22 @@ def reset_classifier(self, num_classes: int):
         if num_classes > 0:
             if self.proj is not None:
                 self.head = nn.Parameter(self.proj_dim, num_classes)
-            else: # no projection (eg PE-lang and PE-spatial)
+            else:  # no projection (eg PE-lang and PE-spatial)
                 self.head = nn.Parameter(self.width, num_classes)
         else:
             self.head = nn.Identity()
 
     def forward_intermediates(
-            self,
-            x: torch.Tensor,
-            indices: Optional[Union[int, List[int]]] = None,
-            return_prefix_tokens: bool = False,
-            norm: bool = False,
-            stop_early: bool = False,
-            output_fmt: str = 'NCHW',
-            intermediates_only: bool = False,
+        self,
+        x: torch.Tensor,
+        indices: Optional[Union[int, List[int]]] = None,
+        return_prefix_tokens: bool = False,
+        norm: bool = False,
+        stop_early: bool = False,
+        output_fmt: str = 'NCHW',
+        intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
-        """ Forward features that returns intermediates.
+        """Forward features that returns intermediates.
 
         Args:
             x: Input image tensor
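
The signature above follows timm's `forward_intermediates` convention; a hedged usage sketch, assuming this module is registered in the installed timm so that the model name added further down in this diff resolves via `timm.create_model`:

```python
import timm
import torch

# Assumption: the 'vit_pe_core_base_patch16_224' entry from default_cfgs below is registered.
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Tuple return: the final token sequence plus the selected intermediate feature maps.
final, intermediates = model.forward_intermediates(x, indices=[3, 7, 11], output_fmt='NCHW')
print(final.shape, [t.shape for t in intermediates])

# intermediates_only=True returns just the list of intermediate feature maps.
feats = model.forward_intermediates(x, indices=3, intermediates_only=True)
```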
@@ -612,7 +603,7 @@ def forward_intermediates(
         B, _, height, width = x.shape
         # patch embedgging
         x = self.conv1(x)
-        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC
+        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC
 
         if self.class_embedding is not None:
             x = torch.cat(
@@ -628,7 +619,7 @@ def forward_intermediates(
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.transformer.resblocks
         else:
-            blocks = self.transformer.resblocks[:max_index + 1]
+            blocks = self.transformer.resblocks[: max_index + 1]
 
         for i, blk in enumerate(blocks):
             x = blk(x)
@@ -638,7 +629,7 @@ def forward_intermediates(
 
         # process intermediates
         if self.class_embedding is not None:
-            prefix_tokens = [y[:, 0] for y in intermediates] # only one cls token in PE
+            prefix_tokens = [y[:, 0] for y in intermediates]  # only one cls token in PE
             intermediates = [y[:, 1:] for y in intermediates]
         else:
             prefix_tokens = None
@@ -657,7 +648,6 @@ def forward_intermediates(
         x = self.ln_post(x)
 
         return x, intermediates
-
 
 
 def checkpoint_filter_fn(
@@ -677,18 +667,20 @@ def _cfg(url='', **kwargs):
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5)
-        'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5)
+        'mean': IMAGENET_INCEPTION_MEAN,  # (0.5, 0.5, 0.5)
+        'std': IMAGENET_INCEPTION_STD,  # (0.5, 0.5, 0.5)
         'first_conv': 'conv1',
-        'classifier': 'head',
+        'classifier': 'head',
         **kwargs,
     }
 
 
 default_cfgs = generate_default_cfgs(
     {
         # TODO finalize locations
-        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)),
+        'vit_pe_core_base_patch16_224': _cfg(
+            hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)
+        ),
         'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
         'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
         'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
@@ -822,4 +814,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         ls_init_value=0.1,
         use_proj=False,
     )
-    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))