Skip to content

Commit 03a6244

Browse files
normsterchhluo
authored and committed
[fix]: Fix swin backbone absolute pos_embed (#8127)
* Fix swin backbone absolute pos_embed resizing
* fix lint
* fix lint
* add unit test
* Update swin.py

Co-authored-by: Cedric Luo <[email protected]>
1 parent 9cd074a commit 03a6244

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

mmdet/models/backbones/swin.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -588,9 +588,8 @@ def __init__(self,
588588
if self.use_abs_pos_embed:
589589
patch_row = pretrain_img_size[0] // patch_size
590590
patch_col = pretrain_img_size[1] // patch_size
591-
num_patches = patch_row * patch_col
592591
self.absolute_pos_embed = nn.Parameter(
593-
torch.zeros((1, num_patches, embed_dims)))
592+
torch.zeros((1, embed_dims, patch_row, patch_col)))
594593

595594
self.drop_after_pos = nn.Dropout(p=drop_rate)
596595

@@ -746,7 +745,17 @@ def forward(self, x):
746745
x, hw_shape = self.patch_embed(x)
747746

748747
if self.use_abs_pos_embed:
749-
x = x + self.absolute_pos_embed
748+
h, w = self.absolute_pos_embed.shape[1:3]
749+
if hw_shape[0] != h or hw_shape[1] != w:
750+
absolute_pos_embed = F.interpolate(
751+
self.absolute_pos_embed,
752+
size=hw_shape,
753+
mode='bicubic',
754+
align_corners=False).flatten(2).transpose(1, 2)
755+
else:
756+
absolute_pos_embed = self.absolute_pos_embed.flatten(
757+
2).transpose(1, 2)
758+
x = x + absolute_pos_embed
750759
x = self.drop_after_pos(x)
751760

752761
outs = []

tests/test_models/test_backbones/test_swin.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,11 @@ def test_swin_transformer():
4444
model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)
4545
model.init_weights()
4646
model(temp)
47+
# Test different input sizes when using absolute position embedding
48+
temp = torch.randn((1, 3, 112, 112))
49+
model(temp)
50+
temp = torch.randn((1, 3, 256, 256))
51+
model(temp)
4752

4853
# Test patch norm
4954
model = SwinTransformer(patch_norm=False)

0 commit comments

Comments
 (0)