
Commit 2587101

forgot task conditioning for vat

1 parent e66862b

2 files changed: 49 additions, 4 deletions


pyproject.toml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.1"
+version = "1.14.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/vat.py (48 additions, 3 deletions)

@@ -21,6 +21,27 @@ def pair(t):
 
 # classes
 
+class FiLM(Module):
+    def __init__(
+        self,
+        dim,
+    ):
+        super().__init__()
+        proj = nn.Linear(dim, dim * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            proj,
+            Rearrange('b (two d) -> two b 1 d', two = 2)
+        )
+
+        nn.init.zeros_(proj.weight)
+        nn.init.zeros_(proj.bias)
+
+    def forward(self, tokens, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+
+        return tokens * (gamma + 1.) + beta # +1 keeps the zero-init projection an identity
+
 class FeedForward(Module):
     def __init__(
         self,
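With the projection zero-initialized, gamma and beta both start at zero and FiLM begins as the identity, so task conditioning cannot perturb the network at the start of training. A minimal check of that property, assuming the FiLM class above is in scope and using illustrative shapes:

import torch

film = FiLM(dim = 64)

tokens = torch.randn(2, 7, 64)   # (batch, action chunk length, dim)
cond = torch.randn(2, 64)        # (batch, dim) - e.g. a task embedding

# gamma and beta are zero at init, so the modulation is a no-op
assert torch.allclose(film(tokens, cond), tokens)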
@@ -228,6 +249,7 @@ def __init__(
         dim_action,
         mlp_dim,
         num_views = None,
+        num_tasks = None,
         dim_extra_token = None,
         action_chunk_len = 7,
         time_seq_len = 1,
@@ -266,16 +288,25 @@ def __init__(
 
         self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
 
+        # handle maybe task conditioning
+
+        self.has_tasks = exists(num_tasks)
+
+        if self.has_tasks:
+            self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
 
         self.layers = ModuleList([])
 
         for _ in range(depth):
+            maybe_film = FiLM(dim = dim) if self.has_tasks else None
             maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
 
             self.layers.append(ModuleList([
+                maybe_film,
                 maybe_self_attn,
                 Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
                 FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
@@ -294,7 +325,9 @@ def __init__(
     def forward(
         self,
         video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
+        *,
         extra = None, # (b d) - batch, dim extra
+        tasks = None, # (b)
         actions = None, # (b k d) - batch, action chunk length, action dimension
     ):
         batch = video_or_image.shape[0]
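The bare * added to the signature makes extra, tasks, and actions keyword-only, so call sites cannot silently bind them by position. A quick illustration, assuming a constructed vat and the images tensor from the usage block further down:

vat(images, extra = extra)   # ok
vat(images, extra)           # TypeError - extra can no longer be passed positionally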
@@ -349,6 +382,13 @@ def forward(
             view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
             hiddens = hiddens + view_emb
 
+        # maybe tasks
+
+        if exists(tasks):
+            assert self.has_tasks, '`num_tasks` must be set on `VAT` for task conditioning'
+
+            task_emb = self.task_emb[tasks]
+
         # cross from actions to representation trajectory
 
         context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
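self.task_emb[tasks] is plain tensor indexing into the learned table: a batch of integer task ids selects rows, giving a (batch, dim) conditioning vector, and gradients flow back only to the selected rows. A standalone sketch with hypothetical sizes:

import torch
from torch import nn

num_tasks, dim = 4, 64
task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)

tasks = torch.tensor([0, 3])   # (batch,) integer task ids
cond = task_emb[tasks]         # (batch, dim), differentiable row selection

assert cond.shape == (2, 64)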
@@ -368,7 +408,10 @@ def forward(
 
         # cross attention
 
-        for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+        for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+
+            if exists(tasks):
+                action_tokens = maybe_film(action_tokens, task_emb)
 
             action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
 
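Note the FiLM update is applied in place, without a residual branch: since the zero-initialized projection makes it the identity, each layer still starts out as plain cross-attention. Spelled out with the broadcast shapes assumed above:

# gamma, beta: (b, 1, d), broadcast across the action chunk length
# at init gamma == 0 and beta == 0, so action_tokens pass through unchanged
action_tokens = action_tokens * (gamma + 1.) + beta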
@@ -422,6 +465,7 @@ def forward(
         action_chunk_len = 7,
         time_seq_len = 4,
         num_views = 2,
+        num_tasks = 4,
         add_self_attn = True,
         dim_extra_token = 33, # extra token with some variable dimension
         vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +474,16 @@ def forward(
     )
 
     images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
+    tasks = torch.randint(0, 4, (2,))
     extra = torch.randn(2, 33) # extra internal state
 
     actions = torch.randn(2, 7, 20) # actions for learning
 
-    loss = vat(images, actions = actions, extra = extra)
+    loss = vat(images, actions = actions, tasks = tasks, extra = extra)
     loss.backward()
 
     # after much training
 
-    pred_actions = vat(images)
+    pred_actions = vat(images, tasks = tasks, extra = extra)
 
     assert pred_actions.shape == (2, 7, 20)
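Reusing the tensors from the block above, a minimal training-step sketch; the optimizer choice is a hypothetical for illustration, not part of the commit:

import torch

optim = torch.optim.Adam(vat.parameters(), lr = 3e-4)

for _ in range(2):   # a couple of dummy steps
    loss = vat(images, actions = actions, tasks = tasks, extra = extra)
    optim.zero_grad()
    loss.backward()
    optim.step()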
