@@ -288,7 +288,7 @@ def __init__(
         self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
         self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.attention = SpatialAttention(**dd)
+        self.attn = SpatialAttention(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x
@@ -301,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.pwconv2(x)
         x = x.permute(0, 3, 1, 2)

-        attention = self.attention(x)
-        up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
-        x = x * up_attn
+        attn = self.attn(x)
+        attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
+        x = x * attn

         return shortcut + self.drop_path(x)

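For reference, a minimal standalone sketch of the gating step in this hunk, assuming the renamed attn module (SpatialAttention, changed further down) returns a (B, 1, 7, 7) map; the tensors here are stand-ins, not part of the model:

import torch
import torch.nn.functional as F

x = torch.randn(2, 32, 28, 28)     # Block features (B, C, H, W)
attn_map = torch.rand(2, 1, 7, 7)  # stand-in for self.attn(x)
# upsample the coarse map back to the feature resolution, then gate the features
attn_map = F.interpolate(attn_map, size=x.shape[2:], mode='bilinear', align_corners=True)
x = x * attn_map                   # (B, 1, H, W) broadcasts over the channel dimension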
@@ -371,15 +371,15 @@ def __init__(
         super().__init__()
         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
         self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd)
-        self.attention = SpatialTransformerBlock(**dd)
+        self.attn = SpatialTransformerBlock(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_avg = x.mean(dim=1, keepdim=True)
         x_max = x.amax(dim=1, keepdim=True)
         x = torch.cat([x_avg, x_max], dim=1)
         x = self.avgpool(x)
         x = self.conv(x)
-        x = self.attention(x)
+        x = self.attn(x)
         return x


@@ -395,6 +395,7 @@ def __init__(
             downsample: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            drop_path: float = 0.,
             device=None,
             dtype=None,
     ) -> None:
@@ -423,9 +424,11 @@ def __init__(
             proj_drop=proj_drop,
             **dd,
         )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = nn.LayerNorm(oup, **dd)
         self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.downsample:
@@ -437,7 +440,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = self.pos_embed(x_t, (H, W))
             x_t = self.attn(x_t)
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
-            x = shortcut + x_t
+            x = shortcut + self.drop_path1(x_t)
         else:
             B, C, H, W = x.shape
             shortcut = x
@@ -446,15 +449,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = self.pos_embed(x_t, (H, W))
             x_t = self.attn(x_t)
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
-            x = shortcut + x_t
+            x = shortcut + self.drop_path1(x_t)

         # MLP block
         B, C, H, W = x.shape
         shortcut = x
         x_t = x.flatten(2).transpose(1, 2)
         x_t = self.mlp(self.norm2(x_t))
         x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
-        x = shortcut + x_t
+        x = shortcut + self.drop_path2(x_t)

         return x

@@ -491,69 +494,64 @@ def __init__(
             self,
             num_classes: int = 1000,
             in_chans: int = 3,
+            dims: Tuple[int, ...] = (32, 72, 168, 386),
+            depths: Tuple[int, ...] = (2, 2, 8, 6),
+            transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
             drop_path_rate: float = 0.0,
+            transformer_drop_path: bool = False,
             global_pool: str = 'avg',
             device=None,
             dtype=None,
             **kwargs,
     ) -> None:
         dd = dict(device=device, dtype=dtype)
         super().__init__()
+        if in_chans != 3:
+            warnings.warn(
+                f'CSATv2 is designed for 3-channel RGB input. '
+                f'in_chans={in_chans} may not work correctly with the DCT stem.'
+            )
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.grad_checkpointing = False

-        dims = [32, 72, 168, 386]
         self.num_features = dims[-1]
         self.head_hidden_size = self.num_features

-        self.feature_info = [
-            dict(num_chs=dims[0], reduction=8, module='stem_dct'),
-            dict(num_chs=dims[0], reduction=8, module='stages.0'),
-            dict(num_chs=dims[1], reduction=16, module='stages.1'),
-            dict(num_chs=dims[2], reduction=32, module='stages.2'),
-            dict(num_chs=dims[3], reduction=64, module='stages.3'),
-        ]
-
-        depths = [2, 2, 6, 4]
-        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        # Build feature_info dynamically
+        self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')]
+        reduction = 8
+        for i, dim in enumerate(dims):
+            if i > 0:
+                reduction *= 2
+            self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}'))
+
+        # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False)
+        total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths))
+        dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
+        dp_rates = []
+        for depth, t_depth in zip(depths, transformer_depths):
+            dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
+            dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]

         self.stem_dct = LearnableDct2d(8, **dd)

-        self.stages = nn.Sequential(
-            nn.Sequential(
-                Block(dim=dims[0], drop_path=dp_rates[0], **dd),
-                Block(dim=dims[0], drop_path=dp_rates[1], **dd),
-                LayerNorm2d(dims[0], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[1], drop_path=dp_rates[2], **dd),
-                Block(dim=dims[1], drop_path=dp_rates[3], **dd),
-                LayerNorm2d(dims[1], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[2], drop_path=dp_rates[4], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[5], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[6], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[7], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[8], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[9], **dd),
-                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
-                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
-                LayerNorm2d(dims[2], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[3], drop_path=dp_rates[10], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[11], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[12], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[13], **dd),
-                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
-                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
-            ),
-        )
+        # Build stages dynamically
+        dp_iter = iter(dp_rates)
+        stages = []
+        for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)):
+            layers = (
+                # Downsample at start of stage (except first stage)
+                ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
+                # Conv blocks
+                [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
+                # Transformer blocks at end of stage
+                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
+                # Trailing LayerNorm (except last stage)
+                ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
+            )
+            stages.append(nn.Sequential(*layers))
+        self.stages = nn.Sequential(*stages)

         self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)

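As a sanity check on the schedule built in this hunk, a small standalone sketch under the new defaults (depths=(2, 2, 8, 6), transformer_depths=(0, 0, 2, 2), transformer_drop_path=False) with an assumed drop_path_rate of 0.1:

import torch

depths, t_depths, drop_path_rate = (2, 2, 8, 6), (0, 0, 2, 2), 0.1
total_blocks = sum(d - t for d, t in zip(depths, t_depths))      # 14 conv blocks span the linspace
dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
dp_rates = []
for depth, t_depth in zip(depths, t_depths):
    dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]  # conv blocks get increasing rates
    dp_rates += [0. for _ in range(t_depth)]                     # transformer blocks stay at 0
assert len(dp_rates) == sum(depths)                              # one rate per block, 18 in total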
@@ -748,20 +746,28 @@ def remap_stage(m):
         elif '.attn_norm.' in k:
             k = k.replace('.attn_norm.', '.norm1.')

-        # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention.
-        # and remap to_qkv -> qkv
-        if '.attention.attention.attn.' in k:
-            k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.')
-            k = k.replace('.attention.attention.attn.', '.attention.attention.')
+        # Block.attention -> Block.attn (SpatialAttention)
+        # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock)
+        # Handle nested .attention.attention. first, then remaining .attention.
+        if '.attention.attention.' in k:
+            # SpatialTransformerBlock inner attn: remap to_qkv -> qkv
+            k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
+            k = k.replace('.attention.attention.attn.', '.attn.attn.')
+            k = k.replace('.attention.attention.', '.attn.attn.')
+        elif '.attention.' in k:
+            # Block.attention -> Block.attn (catches SpatialAttention.conv etc)
+            k = k.replace('.attention.', '.attn.')

         # TransformerBlock: remap attention layer names
         # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed
+        # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.)
         if '.attn.to_qkv.' in k:
             k = k.replace('.attn.to_qkv.', '.attn.qkv.')
         elif '.attn.to_out.0.' in k:
             k = k.replace('.attn.to_out.0.', '.attn.proj.')

-        if '.attn.pos_embed.' in k:
+        # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed)
+        if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
             k = k.replace('.attn.pos_embed.', '.pos_embed.')

         # Remap head -> head.fc, norm -> head.norm (order matters)
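To illustrate, a standalone trace of a subset of these replace rules on two made-up checkpoint keys; the key paths are hypothetical, only the replacement strings come from this diff:

def _remap(k: str) -> str:
    # nested SpatialTransformerBlock weights under a Block's SpatialAttention
    if '.attention.attention.' in k:
        k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
        k = k.replace('.attention.attention.attn.', '.attn.attn.')
        k = k.replace('.attention.attention.', '.attn.attn.')
    elif '.attention.' in k:
        k = k.replace('.attention.', '.attn.')
    # TransformerBlock attention and positional embedding
    if '.attn.to_qkv.' in k:
        k = k.replace('.attn.to_qkv.', '.attn.qkv.')
    if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
        k = k.replace('.attn.pos_embed.', '.pos_embed.')
    return k

assert _remap('stages.2.1.attention.attention.attn.to_qkv.weight') == 'stages.2.1.attn.attn.qkv.weight'
assert _remap('stages.2.7.attn.pos_embed.weight') == 'stages.2.7.pos_embed.weight'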