 
 import torch
 from torch import is_tensor, randn
-from torch.nn import Module, Parameter
+from torch.nn import Module, Linear, Parameter
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat
@@ -26,7 +26,8 @@ def __init__(
         dim_emb = None,
         time_seq_len = None,
         embed_is_channel_first = False,
-        output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
+        output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
+        proj_embed_to_dim = None
     ):
         super().__init__()
         self.image_net = image_net
@@ -35,11 +36,23 @@ def __init__(
         self.add_time_pos_emb = add_time_pos_emb
         self.output_pos_add_pos_emb = output_pos_add_pos_emb
 
+        # maybe project the image embedding
+
+        self.embed_proj = None
+
+        if exists(proj_embed_to_dim):
+            assert exists(dim_emb), '`dim_emb` must be passed in'
+            self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
+
+        # time positional embedding
+
         if add_time_pos_emb:
             assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
             self.time_seq_len = time_seq_len
 
-            self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
+            dim_pos_emb = default(proj_embed_to_dim, dim_emb)
+
+            self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
 
         self.embed_is_channel_first = embed_is_channel_first
 
@@ -79,6 +92,15 @@ def forward(
 
         outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
 
+        # maybe project embedding
+
+        if exists(self.embed_proj):
+            outputs = list(outputs)
+
+            embed = outputs[self.output_pos_add_pos_emb]
+
+            outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
+
         # maybe add time positional embedding
 
         if add_time_pos_emb:
@@ -131,9 +153,9 @@ def forward(
     from vit_pytorch.extractor import Extractor
     v = Extractor(v)
 
-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
 
     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
 
     assert logits.shape == (1, 7, 1000)
-    assert embeddings.shape == (1, 7, 65, 1024)
+    assert embeddings.shape == (1, 7, 65, 512)
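For reference, a minimal standalone sketch (not part of the diff) of what the new `proj_embed_to_dim` path does: the selected output embedding is passed through a `Linear` projection, and the time positional embedding is sized to the projected dimension before being added. Shapes are assumed from the usage example above (7 frames, 65 tokens per frame, `dim_emb = 1024` projected to 512); variable names here are illustrative only.

import torch
from torch import nn

dim_emb, proj_embed_to_dim, time_seq_len = 1024, 512, 12
frames, tokens = 7, 65

# projection enabled by `proj_embed_to_dim`, plus a positional embedding sized to the
# projected dimension (mirroring `default(proj_embed_to_dim, dim_emb)` in the diff)
embed_proj = nn.Linear(dim_emb, proj_embed_to_dim)
pos_emb = nn.Parameter(torch.randn(time_seq_len, proj_embed_to_dim) * 1e-2)

# hypothetical per-frame embeddings from the wrapped image net: (batch, time, tokens, dim_emb)
embeddings = torch.randn(1, frames, tokens, dim_emb)

projected = embed_proj(embeddings)              # (1, 7, 65, 512)
projected = projected + pos_emb[:frames, None]  # broadcast the per-frame pos emb over batch and tokens

assert projected.shape == (1, frames, tokens, proj_embed_to_dim)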