@@ -15,7 +15,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import load_pretrained
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from timm.models.registry import register_model
 
 from functools import partial
@@ -54,26 +54,6 @@ def _cfg_coat(url='', **kwargs):
 }
 
 
-class Mlp(nn.Module):
-    """ Feed-forward network (FFN, a.k.a. MLP) class. """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
 class ConvRelPosEnc(nn.Module):
     """ Convolutional relative position encoding. """
     def __init__(self, Ch, h, window):
@@ -348,34 +328,6 @@ def forward(self, x1, x2, x3, x4, sizes):
         return x1, x2, x3, x4
 
 
-class PatchEmbed(nn.Module):
-    """ Image to Patch Embedding """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
-        super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-
-        self.img_size = img_size
-        self.patch_size = patch_size
-        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
-            f"img_size {img_size} should be divided by patch_size {patch_size}."
-        # Note: self.H, self.W and self.num_patches are not used
-        # since the image size may change on the fly.
-        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
-        self.num_patches = self.H * self.W
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        self.norm = nn.LayerNorm(embed_dim)
-
-    def forward(self, x):
-        _, _, H, W = x.shape
-        out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]
-
-        x = self.proj(x).flatten(2).transpose(1, 2)
-        out = self.norm(x)
-
-        return out, (out_H, out_W)
-
-
 class CoaT(nn.Module):
     """ CoaT class. """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],
@@ -391,13 +343,17 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
 
         # Patch embeddings.
         self.patch_embed1 = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+            embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
         self.patch_embed2 = PatchEmbed(
-            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
+            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
+            embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
         self.patch_embed3 = PatchEmbed(
-            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
+            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
+            embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
         self.patch_embed4 = PatchEmbed(
-            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
+            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
+            embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
 
         # Class tokens.
         self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
@@ -533,31 +489,35 @@ def forward_features(self, x0):
         B = x0.shape[0]
 
         # Serial blocks 1.
-        x1, (H1, W1) = self.patch_embed1(x0)
+        x1 = self.patch_embed1(x0)
+        H1, W1 = self.patch_embed1.grid_size
         x1 = self.insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(H1, W1))
         x1_nocls = self.remove_cls(x1)
         x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 2.
-        x2, (H2, W2) = self.patch_embed2(x1_nocls)
+        x2 = self.patch_embed2(x1_nocls)
+        H2, W2 = self.patch_embed2.grid_size
         x2 = self.insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(H2, W2))
         x2_nocls = self.remove_cls(x2)
         x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 3.
-        x3, (H3, W3) = self.patch_embed3(x2_nocls)
+        x3 = self.patch_embed3(x2_nocls)
+        H3, W3 = self.patch_embed3.grid_size
         x3 = self.insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(H3, W3))
         x3_nocls = self.remove_cls(x3)
         x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 4.
-        x4, (H4, W4) = self.patch_embed4(x3_nocls)
+        x4 = self.patch_embed4(x3_nocls)
+        H4, W4 = self.patch_embed4.grid_size
         x4 = self.insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(H4, W4))
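For reference, a minimal sketch (not part of the diff) of how the shared timm layers that now back CoaT behave; the concrete sizes are illustrative only:

import torch
import torch.nn as nn
from timm.models.layers import Mlp, PatchEmbed

# PatchEmbed with norm_layer=nn.LayerNorm matches the removed local class:
# Conv2d projection, flatten to (B, N, C), then LayerNorm. Its grid_size
# attribute stands in for the (H, W) tuple the old forward() returned.
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=64, norm_layer=nn.LayerNorm)
tokens = embed(torch.randn(1, 3, 224, 224))  # -> (1, 196, 64) flattened patch tokens
H, W = embed.grid_size                       # -> (14, 14), the spatial size passed to the serial blocks

# Mlp keeps the same signature as the removed local Mlp class.
ffn = Mlp(in_features=64, hidden_features=256, drop=0.1)
out = ffn(tokens)                            # -> (1, 196, 64)

Note that grid_size is derived from the configured img_size at construction time rather than from the runtime input, unlike the removed local forward(), which computed (H, W) from the actual tensor shape.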