11import torch
22from torch import nn
33
4- from einops import rearrange , repeat
4+ from einops import rearrange , repeat , reduce
55from einops .layers .torch import Rearrange
66
77# helpers
88
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or an empty container still count
    as existing; only ``None`` itself yields ``False``.
    """
    return val is not None
11+
def pair(t):
    """Return *t* as a tuple, duplicating a non-tuple value into a 2-tuple.

    A value that is already a ``tuple`` is returned unchanged (whatever its
    length); any other value ``t`` becomes ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
1114
@@ -106,20 +109,25 @@ def __init__(
106109 assert image_height % patch_height == 0 and image_width % patch_width == 0 , 'Image dimensions must be divisible by the patch size.'
107110 assert frames % frame_patch_size == 0 , 'Frames must be divisible by frame patch size'
108111
109- num_patches = (image_height // patch_height ) * (image_width // patch_width ) * (frames // frame_patch_size )
112+ num_image_patches = (image_height // patch_height ) * (image_width // patch_width )
113+ num_frame_patches = (frames // frame_patch_size )
114+
110115 patch_dim = channels * patch_height * patch_width * frame_patch_size
111116
112117 assert pool in {'cls' , 'mean' }, 'pool type must be either cls (cls token) or mean (mean pooling)'
113118
119+ self .global_average_pool = pool == 'mean'
120+
114121 self .to_patch_embedding = nn .Sequential (
115122 Rearrange ('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)' , p1 = patch_height , p2 = patch_width , pf = frame_patch_size ),
116123 nn .Linear (patch_dim , dim ),
117124 )
118125
119- self .pos_embedding = nn .Parameter (torch .randn (1 , num_patches + 1 , dim ))
126+ self .pos_embedding = nn .Parameter (torch .randn (1 , num_frame_patches , num_image_patches , dim ))
120127 self .dropout = nn .Dropout (emb_dropout )
121- self .spatial_cls_token = nn .Parameter (torch .randn (1 , 1 , dim ))
122- self .temporal_cls_token = nn .Parameter (torch .randn (1 , 1 , dim ))
128+
129+ self .spatial_cls_token = nn .Parameter (torch .randn (1 , 1 , dim )) if not self .global_average_pool else None
130+ self .temporal_cls_token = nn .Parameter (torch .randn (1 , 1 , dim )) if not self .global_average_pool else None
123131
124132 self .spatial_transformer = Transformer (dim , spatial_depth , heads , dim_head , mlp_dim , dropout )
125133 self .temporal_transformer = Transformer (dim , temporal_depth , heads , dim_head , mlp_dim , dropout )
@@ -132,13 +140,16 @@ def __init__(
132140 nn .Linear (dim , num_classes )
133141 )
134142
135- def forward (self , img ):
136- x = self .to_patch_embedding (img )
143+ def forward (self , video ):
144+ x = self .to_patch_embedding (video )
137145 b , f , n , _ = x .shape
138146
139- spatial_cls_tokens = repeat (self .spatial_cls_token , '1 1 d -> b f 1 d' , b = b , f = f )
140- x = torch .cat ((spatial_cls_tokens , x ), dim = 2 )
141- x += self .pos_embedding [:, :(n + 1 )]
147+ x = x + self .pos_embedding
148+
149+ if exists (self .spatial_cls_token ):
150+ spatial_cls_tokens = repeat (self .spatial_cls_token , '1 1 d -> b f 1 d' , b = b , f = f )
151+ x = torch .cat ((spatial_cls_tokens , x ), dim = 2 )
152+
142153 x = self .dropout (x )
143154
144155 x = rearrange (x , 'b f n d -> (b f) n d' )
@@ -149,21 +160,24 @@ def forward(self, img):
149160
150161 x = rearrange (x , '(b f) n d -> b f n d' , b = b )
151162
152- # excise out the spatial cls tokens for temporal attention
163+ # excise out the spatial cls tokens or average pool for temporal attention
153164
154- x = x [:, :, 0 ]
165+ x = x [:, :, 0 ] if not self . global_average_pool else reduce ( x , 'b f n d -> b f d' , 'mean' )
155166
156167 # append temporal CLS tokens
157168
158- temporal_cls_tokens = repeat (self .temporal_cls_token , '1 1 d-> b 1 d' , b = b )
169+ if exists (self .temporal_cls_token ):
170+ temporal_cls_tokens = repeat (self .temporal_cls_token , '1 1 d-> b 1 d' , b = b )
159171
160- x = torch .cat ((temporal_cls_tokens , x ), dim = 1 )
172+ x = torch .cat ((temporal_cls_tokens , x ), dim = 1 )
161173
162174 # attend across time
163175
164176 x = self .temporal_transformer (x )
165177
166- x = x .mean (dim = 1 ) if self .pool == 'mean' else x [:, 0 ]
178+ # excise out temporal cls token or average pool
179+
180+ x = x [:, 0 ] if not self .global_average_pool else reduce (x , 'b f d -> b d' , 'mean' )
167181
168182 x = self .to_latent (x )
169183 return self .mlp_head (x )
0 commit comments