Commit b6eb61a

Another round of consistency changes for csatv2: make stage building dynamic for other network shapes, allow a drop path option for transformer blocks.
1 parent 6cec2f0 commit b6eb61a
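
The headline changes are the dynamic stage construction and the optional drop path for transformer blocks. A hedged sketch of how the new constructor arguments compose; the class name and import path are assumptions (not confirmed by this diff), but the keyword arguments come from the updated __init__ further down:

# Assumption: the model class in timm/models/csatv2.py is exposed as CSATv2.
from timm.models.csatv2 import CSATv2

model = CSATv2(
    num_classes=1000,
    dims=(48, 96, 192, 384),          # alternative network shape (illustrative values)
    depths=(2, 2, 6, 4),              # total blocks per stage
    transformer_depths=(0, 0, 1, 1),  # trailing TransformerBlocks per stage
    drop_path_rate=0.1,
    transformer_drop_path=True,       # new: include transformer blocks in the drop path schedule
)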

File tree

1 file changed, +66 -60 lines changed


timm/models/csatv2.py

Lines changed: 66 additions & 60 deletions
@@ -288,7 +288,7 @@ def __init__(
         self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
         self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.attention = SpatialAttention(**dd)
+        self.attn = SpatialAttention(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x
@@ -301,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.pwconv2(x)
         x = x.permute(0, 3, 1, 2)

-        attention = self.attention(x)
-        up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
-        x = x * up_attn
+        attn = self.attn(x)
+        attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
+        x = x * attn

         return shortcut + self.drop_path(x)

@@ -371,15 +371,15 @@ def __init__(
         super().__init__()
         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
         self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd)
-        self.attention = SpatialTransformerBlock(**dd)
+        self.attn = SpatialTransformerBlock(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_avg = x.mean(dim=1, keepdim=True)
         x_max = x.amax(dim=1, keepdim=True)
         x = torch.cat([x_avg, x_max], dim=1)
         x = self.avgpool(x)
         x = self.conv(x)
-        x = self.attention(x)
+        x = self.attn(x)
         return x


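For reference, the renamed SpatialAttention path works on a fixed 7x7 grid: channel mean/max statistics are pooled down to 7x7 and reduced to a single-channel map, which Block.forward then upsamples back to the feature resolution before the elementwise gate (the F.interpolate call in the earlier hunk). A minimal shape walkthrough with plain tensors, treating the SpatialTransformerBlock refinement as identity and using made-up batch/channel/resolution sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical feature map entering a Block: (batch, channels, H, W)
x = torch.randn(2, 72, 28, 28)

# SpatialAttention-style channel statistics -> fixed-size single-channel map
x_avg = x.mean(dim=1, keepdim=True)               # (2, 1, 28, 28)
x_max = x.amax(dim=1, keepdim=True)               # (2, 1, 28, 28)
stats = torch.cat([x_avg, x_max], dim=1)          # (2, 2, 28, 28)
stats = F.adaptive_avg_pool2d(stats, (7, 7))      # (2, 2, 7, 7), fixed grid
conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)  # randomly initialized, for shapes only
attn = conv(stats)                                # (2, 1, 7, 7)
# (the real module refines this 7x7 map with a SpatialTransformerBlock; skipped here)

# Block.forward side: upsample back to the feature resolution and gate the features
attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)  # (2, 1, 28, 28)
out = x * attn                                    # single-channel map broadcasts over channels
print(out.shape)                                  # torch.Size([2, 72, 28, 28])
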
@@ -395,6 +395,7 @@ def __init__(
             downsample: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            drop_path: float = 0.,
             device=None,
             dtype=None,
     ) -> None:
@@ -423,9 +424,11 @@ def __init__(
             proj_drop=proj_drop,
             **dd,
         )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = nn.LayerNorm(oup, **dd)
         self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.downsample:
@@ -437,7 +440,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = self.pos_embed(x_t, (H, W))
             x_t = self.attn(x_t)
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
-            x = shortcut + x_t
+            x = shortcut + self.drop_path1(x_t)
         else:
             B, C, H, W = x.shape
             shortcut = x
@@ -446,15 +449,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = self.pos_embed(x_t, (H, W))
             x_t = self.attn(x_t)
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
-            x = shortcut + x_t
+            x = shortcut + self.drop_path1(x_t)

         # MLP block
         B, C, H, W = x.shape
         shortcut = x
         x_t = x.flatten(2).transpose(1, 2)
         x_t = self.mlp(self.norm2(x_t))
         x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
-        x = shortcut + x_t
+        x = shortcut + self.drop_path2(x_t)

         return x

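The new drop_path1/drop_path2 modules add stochastic depth to the attention and MLP residual branches of TransformerBlock. A minimal sketch of what timm's DropPath does to a residual branch; the tensor shapes and the 0.1 rate are illustrative:

import torch
from timm.layers import DropPath

drop_path = DropPath(drop_prob=0.1)   # illustrative rate
x = torch.randn(4, 64, 14, 14)        # hypothetical residual input
branch = torch.randn(4, 64, 14, 14)   # hypothetical attention/MLP branch output

drop_path.train()
out = x + drop_path(branch)   # training: branch zeroed per sample with p=0.1, else scaled by 1/0.9

drop_path.eval()
out = x + drop_path(branch)   # inference: identity, plain residual add
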
@@ -491,69 +494,64 @@ def __init__(
             self,
             num_classes: int = 1000,
             in_chans: int = 3,
+            dims: Tuple[int, ...] = (32, 72, 168, 386),
+            depths: Tuple[int, ...] = (2, 2, 8, 6),
+            transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
             drop_path_rate: float = 0.0,
+            transformer_drop_path: bool = False,
             global_pool: str = 'avg',
             device=None,
             dtype=None,
             **kwargs,
     ) -> None:
         dd = dict(device=device, dtype=dtype)
         super().__init__()
+        if in_chans != 3:
+            warnings.warn(
+                f'CSATv2 is designed for 3-channel RGB input. '
+                f'in_chans={in_chans} may not work correctly with the DCT stem.'
+            )
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.grad_checkpointing = False

-        dims = [32, 72, 168, 386]
         self.num_features = dims[-1]
         self.head_hidden_size = self.num_features

-        self.feature_info = [
-            dict(num_chs=dims[0], reduction=8, module='stem_dct'),
-            dict(num_chs=dims[0], reduction=8, module='stages.0'),
-            dict(num_chs=dims[1], reduction=16, module='stages.1'),
-            dict(num_chs=dims[2], reduction=32, module='stages.2'),
-            dict(num_chs=dims[3], reduction=64, module='stages.3'),
-        ]
-
-        depths = [2, 2, 6, 4]
-        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        # Build feature_info dynamically
+        self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')]
+        reduction = 8
+        for i, dim in enumerate(dims):
+            if i > 0:
+                reduction *= 2
+            self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}'))
+
+        # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False)
+        total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths))
+        dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
+        dp_rates = []
+        for depth, t_depth in zip(depths, transformer_depths):
+            dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
+            dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]

         self.stem_dct = LearnableDct2d(8, **dd)

-        self.stages = nn.Sequential(
-            nn.Sequential(
-                Block(dim=dims[0], drop_path=dp_rates[0], **dd),
-                Block(dim=dims[0], drop_path=dp_rates[1], **dd),
-                LayerNorm2d(dims[0], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[1], drop_path=dp_rates[2], **dd),
-                Block(dim=dims[1], drop_path=dp_rates[3], **dd),
-                LayerNorm2d(dims[1], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[2], drop_path=dp_rates[4], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[5], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[6], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[7], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[8], **dd),
-                Block(dim=dims[2], drop_path=dp_rates[9], **dd),
-                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
-                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
-                LayerNorm2d(dims[2], eps=1e-6, **dd),
-            ),
-            nn.Sequential(
-                nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2, **dd),
-                Block(dim=dims[3], drop_path=dp_rates[10], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[11], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[12], **dd),
-                Block(dim=dims[3], drop_path=dp_rates[13], **dd),
-                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
-                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
-            ),
-        )
+        # Build stages dynamically
+        dp_iter = iter(dp_rates)
+        stages = []
+        for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)):
+            layers = (
+                # Downsample at start of stage (except first stage)
+                ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
+                # Conv blocks
+                [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
+                # Transformer blocks at end of stage
+                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
+                # Trailing LayerNorm (except last stage)
+                ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
+            )
+            stages.append(nn.Sequential(*layers))
+        self.stages = nn.Sequential(*stages)

         self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)

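To see what the new schedule produces for the default configuration, here is the same drop-path construction run standalone. With transformer_drop_path=False the linspace only spans the conv Blocks (2 + 2 + 6 + 4 = 14, matching the previously hard-coded [2, 2, 6, 4] layout), and the trailing TransformerBlocks in the last two stages get a rate of 0; the drop_path_rate value is illustrative:

import torch

depths = (2, 2, 8, 6)
transformer_depths = (0, 0, 2, 2)
drop_path_rate = 0.1
transformer_drop_path = False

# Same construction as in the commit, reproduced outside the module for illustration
total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths))
dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
dp_rates = []
for depth, t_depth in zip(depths, transformer_depths):
    dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
    dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]

print(total_blocks)                          # 14 conv Blocks (2 + 2 + 6 + 4)
print(len(dp_rates))                         # 18 entries, one per block including transformers
print([round(r, 3) for r in dp_rates[-6:]])
# last stage: 4 conv rates ramping toward 0.1, then 0.0, 0.0 for the two TransformerBlocks

With transformer_drop_path=True the linspace instead covers all sum(depths) = 18 blocks, so the TransformerBlocks receive nonzero rates as well.
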
@@ -748,20 +746,28 @@ def remap_stage(m):
         elif '.attn_norm.' in k:
             k = k.replace('.attn_norm.', '.norm1.')

-        # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention.
-        # and remap to_qkv -> qkv
-        if '.attention.attention.attn.' in k:
-            k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.')
-            k = k.replace('.attention.attention.attn.', '.attention.attention.')
+        # Block.attention -> Block.attn (SpatialAttention)
+        # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock)
+        # Handle nested .attention.attention. first, then remaining .attention.
+        if '.attention.attention.' in k:
+            # SpatialTransformerBlock inner attn: remap to_qkv -> qkv
+            k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
+            k = k.replace('.attention.attention.attn.', '.attn.attn.')
+            k = k.replace('.attention.attention.', '.attn.attn.')
+        elif '.attention.' in k:
+            # Block.attention -> Block.attn (catches SpatialAttention.conv etc)
+            k = k.replace('.attention.', '.attn.')

         # TransformerBlock: remap attention layer names
         # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed
+        # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.)
         if '.attn.to_qkv.' in k:
             k = k.replace('.attn.to_qkv.', '.attn.qkv.')
         elif '.attn.to_out.0.' in k:
             k = k.replace('.attn.to_out.0.', '.attn.proj.')

-        if '.attn.pos_embed.' in k:
+        # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed)
+        if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
             k = k.replace('.attn.pos_embed.', '.pos_embed.')

         # Remap head -> head.fc, norm -> head.norm (order matters)
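
A quick illustration of the updated key remapping, using a helper that mirrors only the attention-related replacements above; the helper name and the example checkpoint keys are hypothetical, chosen to exercise each branch:

def remap_attn(k: str) -> str:
    # Mirrors the replacement order in the hunk above (illustrative only)
    if '.attention.attention.' in k:
        k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
        k = k.replace('.attention.attention.attn.', '.attn.attn.')
        k = k.replace('.attention.attention.', '.attn.attn.')
    elif '.attention.' in k:
        k = k.replace('.attention.', '.attn.')
    if '.attn.to_qkv.' in k:
        k = k.replace('.attn.to_qkv.', '.attn.qkv.')
    elif '.attn.to_out.0.' in k:
        k = k.replace('.attn.to_out.0.', '.attn.proj.')
    if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
        k = k.replace('.attn.pos_embed.', '.pos_embed.')
    return k

# Hypothetical old keys -> expected new module paths
print(remap_attn('stages.0.0.attention.conv.weight'))
# stages.0.0.attn.conv.weight                 (Block.attn = SpatialAttention)
print(remap_attn('stages.2.6.attention.attention.attn.to_qkv.weight'))
# stages.2.6.attn.attn.qkv.weight             (nested SpatialTransformerBlock attention)
print(remap_attn('stages.2.8.attn.to_out.0.weight'))
# stages.2.8.attn.proj.weight                 (TransformerBlock projection)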
