@@ -21,6 +21,27 @@ def pair(t):
2121
2222# classes
2323
class FiLM(Module):
    """Feature-wise linear modulation for conditioning a token sequence.

    A conditioning vector is projected to a per-channel scale (gamma) and
    shift (beta), applied as ``tokens * gamma + beta``. The projection is
    zero-initialized, so gamma = beta = 0 at the start of training.
    NOTE(review): with a non-residual call site this zeroes the tokens at
    init — confirm that warm-start is intended.
    """

    def __init__(self, dim):
        super().__init__()

        linear = nn.Linear(dim, dim * 2)

        # zero init -> the modulation starts as the null transform
        nn.init.zeros_(linear.weight)
        nn.init.zeros_(linear.bias)

        self.to_gamma_beta = nn.Sequential(
            linear,
            Rearrange('b (two d) -> two b 1 d', two = 2)
        )

    def forward(self, tokens, cond):
        # cond: (b, d) -> gamma, beta each (b, 1, d), broadcast over the
        # sequence dimension of tokens
        gamma, beta = self.to_gamma_beta(cond)

        return tokens * gamma + beta
2445class FeedForward (Module ):
2546 def __init__ (
2647 self ,
@@ -228,6 +249,7 @@ def __init__(
228249 dim_action ,
229250 mlp_dim ,
230251 num_views = None ,
252+ num_tasks = None ,
231253 dim_extra_token = None ,
232254 action_chunk_len = 7 ,
233255 time_seq_len = 1 ,
@@ -266,16 +288,25 @@ def __init__(
266288
267289 self .view_emb = nn .Parameter (torch .randn (num_views , vit_dim ) * 1e-2 ) if exists (num_views ) and num_views > 1 else None
268290
291+ # handle maybe task conditioning
292+
293+ self .has_tasks = exists (num_tasks )
294+
295+ if self .has_tasks :
296+ self .task_emb = nn .Parameter (torch .randn (num_tasks , dim ) * 1e-2 )
297+
269298 # to action tokens
270299
271300 self .action_pos_emb = nn .Parameter (torch .randn (action_chunk_len , dim ) * 1e-2 )
272301
273302 self .layers = ModuleList ([])
274303
275304 for _ in range (depth ):
305+ maybe_film = FiLM (dim = dim ) if self .has_tasks else None
276306 maybe_self_attn = Attention (dim = dim , heads = self_attn_heads , dim_head = self_attn_dim_head , dropout = dropout ) if add_self_attn else None
277307
278308 self .layers .append (ModuleList ([
309+ maybe_film ,
279310 maybe_self_attn ,
280311 Attention (dim = dim , heads = heads , dim_head = dim_head , dropout = dropout , cross_attend = True ),
281312 FeedForward (dim = dim , hidden_dim = mlp_dim , dropout = dropout )
@@ -294,7 +325,9 @@ def __init__(
294325 def forward (
295326 self ,
296327 video_or_image , # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
328+ * ,
297329 extra = None , # (b d) - batch, dim extra
330+ tasks = None , # (b)
298331 actions = None , # (b k d) - batch, action chunk length, action dimension
299332 ):
300333 batch = video_or_image .shape [0 ]
@@ -349,6 +382,13 @@ def forward(
349382 view_emb = rearrange (self .view_emb , 'v d -> v 1 1 d' )
350383 hiddens = hiddens + view_emb
351384
385+ # maybe tasks
386+
387+ if exists (tasks ):
388+ assert self .has_tasks , f'`num_tasks` must be set on `VAT` for task conditioning'
389+
390+ task_emb = self .task_emb [tasks ]
391+
352392 # cross from actions to representation trajectory
353393
354394 context = rearrange (hiddens , 'l b v t n d -> l b (v t n) d' )
@@ -368,7 +408,10 @@ def forward(
368408
369409 # cross attention
370410
371- for (maybe_self_attn , cross_attn , ff ), layer_context in zip (self .layers , context ):
411+ for (maybe_film , maybe_self_attn , cross_attn , ff ), layer_context in zip (self .layers , context ):
412+
413+ if exists (tasks ):
414+ action_tokens = maybe_film (action_tokens , task_emb )
372415
373416 action_tokens = cross_attn (action_tokens , layer_context ) + action_tokens
374417
@@ -422,6 +465,7 @@ def forward(
422465 action_chunk_len = 7 ,
423466 time_seq_len = 4 ,
424467 num_views = 2 ,
468+ num_tasks = 4 ,
425469 add_self_attn = True ,
426470 dim_extra_token = 33 , # extra token with some variable dimension
427471 vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +474,16 @@ def forward(
430474 )
431475
432476 images = torch .randn (2 , 2 , 3 , 4 , 256 , 256 ) # (2 views with 4 frames)
477+ tasks = torch .randint (0 , 4 , (2 ,))
433478 extra = torch .randn (2 , 33 ) # extra internal state
434479
435480 actions = torch .randn (2 , 7 , 20 ) # actions for learning
436481
437- loss = vat (images , actions = actions , extra = extra )
482+ loss = vat (images , actions = actions , tasks = tasks , extra = extra )
438483 loss .backward ()
439484
440485 # after much training
441486
442- pred_actions = vat (images )
487+ pred_actions = vat (images , tasks = tasks , extra = extra )
443488
444489 assert pred_actions .shape == (2 , 7 , 20 )
0 commit comments