
Commit 13fabf9

add vivit
1 parent c0eb4c0 commit 13fabf9

4 files changed: 209 additions, 1 deletion

README.md

Lines changed: 39 additions & 0 deletions
@@ -31,6 +31,7 @@
 - [Patch Merger](#patch-merger)
 - [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
 - [3D Vit](#3d-vit)
+- [ViVit](#vivit)
 - [Parallel ViT](#parallel-vit)
 - [Learnable Memory ViT](#learnable-memory-vit)
 - [Dino](#dino)
@@ -1022,6 +1023,34 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
 preds = v(video) # (4, 1000)
 ```
 
+## ViViT
+
+<img src="./images/vivit.png" width="350px"></img>
+
+This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention over videos, with the main theme being factorizing the attention across space and time. This repository offers the first variant, a spatial transformer followed by a temporal one.
+
+```python
+import torch
+from vit_pytorch.vivit import ViT
+
+v = ViT(
+    image_size = 128,          # image size
+    frames = 16,               # number of frames
+    image_patch_size = 16,     # image patch size
+    frame_patch_size = 2,      # frame patch size
+    num_classes = 1000,
+    dim = 1024,
+    spatial_depth = 6,         # depth of the spatial transformer
+    temporal_depth = 6,        # depth of the temporal transformer
+    heads = 8,
+    mlp_dim = 2048
+)
+
+video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
+
+preds = v(video) # (4, 1000)
+```
+
 ## Parallel ViT
 
 <img src="./images/parallel-vit.png" width="350px"></img>
@@ -1805,6 +1834,16 @@ Coming from computer vision and new to transformers? Here are some resources that
 
 ```
 
+```bibtex
+@article{Arnab2021ViViTAV,
+    title   = {ViViT: A Video Vision Transformer},
+    author  = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
+    journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
+    year    = {2021},
+    pages   = {6816-6826}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
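
To make the patch-size settings in the README example concrete, here is a small illustrative sketch (my own, not part of the commit) of the space-time patch rearrangement that `vivit.py` applies before its linear projection, instantiated with the example's values; the variable names `to_patches`, `video`, and `tokens` are mine.

```python
import torch
from einops.layers.torch import Rearrange

# Illustration only: the same rearrangement pattern vivit.py uses for patch
# embedding, with image_patch_size = 16 and frame_patch_size = 2 as in the README example.
to_patches = Rearrange(
    'b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)',
    p1 = 16, p2 = 16, pf = 2
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
tokens = to_patches(video)

# 16 frames / 2 = 8 frame groups; (128 / 16) ** 2 = 64 spatial patches per group;
# each patch flattens to 16 * 16 * 2 * 3 = 1536 values before projection to `dim`.
print(tokens.shape)  # torch.Size([4, 8, 64, 1536])
```

With these settings the spatial transformer attends over 64 + 1 tokens within each of the 8 frame groups, and the temporal transformer attends over the resulting 8 + 1 CLS tokens per video.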

images/vivit.png

104 KB

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.36.2',
+  version = '0.37.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/vivit.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        image_patch_size,
        frames,
        frame_patch_size,
        num_classes,
        dim,
        spatial_depth,
        temporal_depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, f, n, _ = x.shape

        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
        x = torch.cat((spatial_cls_tokens, x), dim = 2)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = rearrange(x, 'b f n d -> (b f) n d')

        # attend across space

        x = self.spatial_transformer(x)

        x = rearrange(x, '(b f) n d -> b f n d', b = b)

        # excise out the spatial cls tokens for temporal attention

        x = x[:, :, 0]

        # append temporal CLS tokens

        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = b)

        x = torch.cat((temporal_cls_tokens, x), dim = 1)

        # attend across time

        x = self.temporal_transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
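
As a rough guide to the factorised forward pass above, here is a shape-bookkeeping sketch of my own (not part of the commit), with toy values matching the README example after patch embedding: b = batch, f = frame groups, n = spatial patches, d = dim.

```python
import torch

# Hypothetical toy values: 4 videos, 8 frame groups, 64 spatial patches, dim 1024.
b, f, n, d = 4, 8, 64, 1024

x = torch.randn(b, f, n, d)              # patch tokens after the linear projection
cls_s = torch.randn(b, f, 1, d)          # one spatial CLS token per frame group
x = torch.cat((cls_s, x), dim = 2)       # (b, f, n + 1, d)

x = x.reshape(b * f, n + 1, d)           # spatial transformer attends within each frame group
x = x.reshape(b, f, n + 1, d)[:, :, 0]   # keep each group's spatial CLS -> (b, f, d)

cls_t = torch.randn(b, 1, d)             # temporal CLS token
x = torch.cat((cls_t, x), dim = 1)       # (b, f + 1, d)

# the temporal transformer then attends across these f + 1 tokens; x[:, 0] feeds the MLP head
print(x.shape)  # torch.Size([4, 9, 1024])
```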
