import torch
-import torch.nn as nn
+from torch import nn, einsum
import torch.nn.functional as F

+from einops import rearrange, repeat
+
# helpers

+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
-    stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
-    padding = padding if padding is not None else max(1, (kernel_size // 2))
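+    # derive default stride and padding from the kernel size when they are not given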
+    stride = default(stride, max(1, (kernel_size // 2) - 1))
+    padding = default(padding, max(1, (kernel_size // 2)))
+
    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
               padding=padding,
               *args, **kwargs)

+# positional
+
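+# fixed sinusoidal position encodings: even feature indices use sine, odd indices use
+# cosine, and a leading batch dimension is added for broadcasting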
+def sinusoidal_embedding(n_channels, dim):
+    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+                            for p in range(n_channels)])
+    pe[:, 0::2] = torch.sin(pe[:, 0::2])
+    pe[:, 1::2] = torch.cos(pe[:, 1::2])
+    return rearrange(pe, '... -> 1 ...')
+
# modules

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // self.num_heads
+        self.heads = num_heads
+        head_dim = dim // self.heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
@@ -77,17 +95,20 @@ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0

    def forward(self, x):
        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]

-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        qkv = self.qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
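+        # scale the queries, then take pairwise query-key dot products per head to get attention logits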
+        q = q * self.scale
+
+        attn = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
+        x = einsum('b h i j, b h j d -> b h i d', attn, v)
+        x = rearrange(x, 'b h n d -> b n (h d)')
+
+        return self.proj_drop(self.proj(x))


class TransformerEncoderLayer(nn.Module):
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
-        super(TransformerEncoderLayer, self).__init__()
+        super().__init__()
+
        self.pre_norm = nn.LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

-        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+        self.drop_path = DropPath(drop_path_rate)

        self.activation = F.gelu

-    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+    def forward(self, src, *args, **kwargs):
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        return src

-
-def drop_path(x, drop_prob: float = 0., training: bool = False):
-    """
-    Obtained from: github.com:rwightman/pytorch-image-models
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-    'survival rate' as the argument.
-    """
-    if drop_prob == 0. or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-    random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
-    return output
-
-
class DropPath(nn.Module):
-    """
-    Obtained from: github.com:rwightman/pytorch-image-models
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-    """
    def __init__(self, drop_prob=None):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
+        super().__init__()
+        self.drop_prob = float(default(drop_prob, 0.))

    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training)
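+        # stochastic depth: during training, drop the entire residual branch for a random
+        # subset of samples and rescale the survivors by 1 / keep_prob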
+        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
+
+        if drop_prob <= 0. or not self.training:
+            return x
+
+        keep_prob = 1 - self.drop_prob
+        shape = (batch, *((1,) * (x.ndim - 1)))

+        keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
+        output = x.div(keep_prob) * keep_mask.float()
+        return output

class Tokenizer(nn.Module):
    def __init__(self,
@@ -164,34 +170,35 @@ def __init__(self,
                 activation=None,
                 max_pool=True,
                 conv_bias=False):
-        super(Tokenizer, self).__init__()
+        super().__init__()

        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

+        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
+
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
-                nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
+                nn.Conv2d(chan_in, chan_out,
                          kernel_size=(kernel_size, kernel_size),
                          stride=(stride, stride),
                          padding=(padding, padding), bias=conv_bias),
-                nn.Identity() if activation is None else activation(),
+                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool2d(kernel_size=pooling_kernel_size,
                             stride=pooling_stride,
                             padding=pooling_padding) if max_pool else nn.Identity()
            )
-                for i in range(n_conv_layers)
+                for chan_in, chan_out in n_filter_list_pairs
            ])

-        self.flattener = nn.Flatten(2, 3)
        self.apply(self.init_weight)

    def sequence_length(self, n_channels=3, height=224, width=224):
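+        # infer the number of output tokens by running a dummy image through the conv tokenizer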
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    def forward(self, x):
-        return self.flattener(self.conv_layers(x)).transpose(-2, -1)
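+        # flatten the conv feature map into a sequence of tokens: (b, c, h, w) -> (b, h*w, c)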
+        return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

    @staticmethod
    def init_weight(m):
@@ -214,106 +221,104 @@ def __init__(self,
                 sequence_length=None,
                 *args, **kwargs):
        super().__init__()
-        positional_embedding = positional_embedding if \
-            positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
+        assert positional_embedding in {'sine', 'learnable', 'none'}
+
        dim_feedforward = int(embedding_dim * mlp_ratio)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.seq_pool = seq_pool

-        assert sequence_length is not None or positional_embedding == 'none', \
+        assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        if not seq_pool:
            sequence_length += 1
-            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
-                                          requires_grad=True)
+            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad = True)
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)

-        if positional_embedding != 'none':
-            if positional_embedding == 'learnable':
-                self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
-                                                   requires_grad=True)
-                nn.init.trunc_normal_(self.positional_emb, std=0.2)
-            else:
-                self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
-                                                   requires_grad=False)
-        else:
+        if positional_embedding == 'none':
            self.positional_emb = None
+        elif positional_embedding == 'learnable':
+            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
+                                               requires_grad = True)
+            nn.init.trunc_normal_(self.positional_emb, std = 0.2)
+        else:
+            self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
+                                               requires_grad = False)

        self.dropout = nn.Dropout(p=dropout_rate)
+
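+        # stochastic depth decay rule: the drop-path rate rises linearly from 0 to
+        # stochastic_depth_rate across the transformer layers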
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
+
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
-                                    attention_dropout=attention_dropout, drop_path_rate=dpr[i])
-            for i in range(num_layers)])
+                                    attention_dropout = attention_dropout, drop_path_rate = layer_dpr)
+            for layer_dpr in dpr])
+
        self.norm = nn.LayerNorm(embedding_dim)

        self.fc = nn.Linear(embedding_dim, num_classes)
        self.apply(self.init_weight)

    def forward(self, x):
-        if self.positional_emb is None and x.size(1) < self.sequence_length:
+        b = x.shape[0]
+
+        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)

        if not self.seq_pool:
-            cls_token = self.class_emb.expand(x.shape[0], -1, -1)
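+            # broadcast the learned class token across the batch and prepend it to the token sequence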
+            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
            x = torch.cat((cls_token, x), dim=1)

-        if self.positional_emb is not None:
+        if exists(self.positional_emb):
            x += self.positional_emb

        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)
+
        x = self.norm(x)

        if self.seq_pool:
-            x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
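+            # sequence pooling: score each token, softmax over the sequence, and take the
+            # attention-weighted average of the tokens in place of a class token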
+            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
+            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
        else:
            x = x[:, 0]

-        x = self.fc(x)
-        return x
+        return self.fc(x)

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if isinstance(m, nn.Linear) and exists(m.bias):
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

-    @staticmethod
-    def sinusoidal_embedding(n_channels, dim):
-        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
-                                for p in range(n_channels)])
-        pe[:, 0::2] = torch.sin(pe[:, 0::2])
-        pe[:, 1::2] = torch.cos(pe[:, 1::2])
-        return pe.unsqueeze(0)
-
-
# CCT Main model
+
class CCT(nn.Module):
-    def __init__(self,
-                 img_size=224,
-                 embedding_dim=768,
-                 n_input_channels=3,
-                 n_conv_layers=1,
-                 kernel_size=7,
-                 stride=2,
-                 padding=3,
-                 pooling_kernel_size=3,
-                 pooling_stride=2,
-                 pooling_padding=1,
-                 *args, **kwargs):
-        super(CCT, self).__init__()
+    def __init__(
+        self,
+        img_size = 224,
+        embedding_dim = 768,
+        n_input_channels = 3,
+        n_conv_layers = 1,
+        kernel_size = 7,
+        stride = 2,
+        padding = 3,
+        pooling_kernel_size = 3,
+        pooling_stride = 2,
+        pooling_padding = 1,
+        *args, **kwargs
+    ):
+        super().__init__()
        img_height, img_width = pair(img_size)

        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,